diff --git a/examples/chat/Dockerfile b/examples/chat/Dockerfile
index d331e8d..c4a2f05 100644
--- a/examples/chat/Dockerfile
+++ b/examples/chat/Dockerfile
@@ -2,6 +2,7 @@ FROM python:3
ENV FLASK_APP=chat.py
ENV FLASK_ENV=development
+ENV SHATTERED_APP=chat.py
COPY requirements.txt /src/requirements.txt
diff --git a/examples/chat/chat.py b/examples/chat/chat.py
index e51357d..b5f8827 100644
--- a/examples/chat/chat.py
+++ b/examples/chat/chat.py
@@ -21,7 +21,3 @@ def echo(headers, body, conn):
data = dict(parse_qsl(body))
email = data["email"]
conn.send(body=f'
{email}: {data["message"]}', destination="/topic/chat")
-
-
-if __name__ == "__main__":
- shattered_app.run()
diff --git a/examples/chat/docker-compose.yml b/examples/chat/docker-compose.yml
index 92256b2..ffd35b7 100644
--- a/examples/chat/docker-compose.yml
+++ b/examples/chat/docker-compose.yml
@@ -15,8 +15,8 @@ services:
"300",
"rabbitmq:61613",
"--",
- "python",
- "chat.py",
+ "shattered",
+ "run",
]
chat:
build: .
diff --git a/pyproject.toml b/pyproject.toml
index 9479744..6b705a5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,12 +1,15 @@
[tool.poetry]
name = "shattered"
-version = "0.2.0"
+version = "0.3.0"
description = "STOMP meets bottle.py"
authors = ["Jimmy Bradshaw "]
repository = "https://github.com/bradshjg/shattered"
license = "MIT"
readme = "README-PYPI.md"
+[tool.poetry.scripts]
+shattered = "shattered.cli:main"
+
[tool.poetry.dependencies]
python = "^3.6"
"stomp.py" = "^5.0.0"
diff --git a/src/shattered/__init__.py b/src/shattered/__init__.py
index c3d2549..831105d 100644
--- a/src/shattered/__init__.py
+++ b/src/shattered/__init__.py
@@ -1,3 +1,3 @@
-__version__ = "0.2.0"
+__version__ = "0.3.0"
from shattered.shattered import Shattered
diff --git a/src/shattered/cli.py b/src/shattered/cli.py
new file mode 100644
index 0000000..2b6d57e
--- /dev/null
+++ b/src/shattered/cli.py
@@ -0,0 +1,47 @@
+import argparse
+import os
+import sys
+
+from .shattered import Shattered
+
+
+def load_app(args):
+ sys.path.insert(0, os.getcwd())
+ module_file = args.app or os.getenv("SHATTERED_APP", "app.py")
+ module_name = os.path.splitext(module_file)[0]
+ __import__(module_name)
+ module = sys.modules[module_name]
+ matches = [v for v in module.__dict__.values() if isinstance(v, Shattered)]
+ if len(matches) == 1:
+ return matches[0]
+ else:
+ raise EnvironmentError("Detected multiple Shattered applications.")
+
+
+def run_app(args):
+ app = load_app(args)
+ app.run()
+
+
+def print_config(args):
+ app = load_app(args)
+ for k, v in app.config.items():
+ print(f"{k}: {v}")
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(prog="shattered")
+ parser.add_argument("-a", "--app", help="shattered app")
+ subparsers = parser.add_subparsers(
+ title="subcommands", description="valid subcommands", help="additional help"
+ )
+ parser_run = subparsers.add_parser("run", help="run shattered app")
+ parser_run.set_defaults(func=run_app)
+ parser_config = subparsers.add_parser("config", help="show shattered app config")
+ parser_config.set_defaults(func=print_config)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_arguments()
+ args.func(args)
diff --git a/src/shattered/shattered.py b/src/shattered/shattered.py
index f9fbfbd..c5884b3 100644
--- a/src/shattered/shattered.py
+++ b/src/shattered/shattered.py
@@ -27,16 +27,25 @@ def on_message(self, headers, body):
class Shattered:
def __init__(self, **config):
self.config = config
- self.host = self.config.get("host", "localhost")
- self.port = self.config.get("port", 61613)
- self.username = self.config.get("username", "guest")
- self.password = self.config.get("password", "guest")
- self.vhost = self.config.get("vhost", "/")
+ self._set_config_defaults()
+
+ self.host = self.config["host"]
+ self.port = self.config["port"]
+ self.username = self.config["username"]
+ self.password = self.config["password"]
+ self.vhost = self.config["vhost"]
self.subscriptions = {}
self.conn = None
self.listener = ShatteredListener(self)
+ def _set_config_defaults(self):
+ self.config.setdefault("host", "localhost")
+ self.config.setdefault("port", 61613)
+ self.config.setdefault("username", "guest")
+ self.config.setdefault("password", "guest")
+ self.config.setdefault("vhost", "/")
+
def add_subscription(self, destination, callback):
if destination not in self.subscriptions:
self.subscriptions[destination] = []