diff --git a/examples/chat/Dockerfile b/examples/chat/Dockerfile index d331e8d..c4a2f05 100644 --- a/examples/chat/Dockerfile +++ b/examples/chat/Dockerfile @@ -2,6 +2,7 @@ FROM python:3 ENV FLASK_APP=chat.py ENV FLASK_ENV=development +ENV SHATTERED_APP=chat.py COPY requirements.txt /src/requirements.txt diff --git a/examples/chat/chat.py b/examples/chat/chat.py index e51357d..b5f8827 100644 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -21,7 +21,3 @@ def echo(headers, body, conn): data = dict(parse_qsl(body)) email = data["email"] conn.send(body=f'
  • {email}: {data["message"]}
  • ', destination="/topic/chat") - - -if __name__ == "__main__": - shattered_app.run() diff --git a/examples/chat/docker-compose.yml b/examples/chat/docker-compose.yml index 92256b2..ffd35b7 100644 --- a/examples/chat/docker-compose.yml +++ b/examples/chat/docker-compose.yml @@ -15,8 +15,8 @@ services: "300", "rabbitmq:61613", "--", - "python", - "chat.py", + "shattered", + "run", ] chat: build: . diff --git a/pyproject.toml b/pyproject.toml index 9479744..6b705a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,15 @@ [tool.poetry] name = "shattered" -version = "0.2.0" +version = "0.3.0" description = "STOMP meets bottle.py" authors = ["Jimmy Bradshaw "] repository = "https://github.com/bradshjg/shattered" license = "MIT" readme = "README-PYPI.md" +[tool.poetry.scripts] +shattered = "shattered.cli:main" + [tool.poetry.dependencies] python = "^3.6" "stomp.py" = "^5.0.0" diff --git a/src/shattered/__init__.py b/src/shattered/__init__.py index c3d2549..831105d 100644 --- a/src/shattered/__init__.py +++ b/src/shattered/__init__.py @@ -1,3 +1,3 @@ -__version__ = "0.2.0" +__version__ = "0.3.0" from shattered.shattered import Shattered diff --git a/src/shattered/cli.py b/src/shattered/cli.py new file mode 100644 index 0000000..2b6d57e --- /dev/null +++ b/src/shattered/cli.py @@ -0,0 +1,47 @@ +import argparse +import os +import sys + +from .shattered import Shattered + + +def load_app(args): + sys.path.insert(0, os.getcwd()) + module_file = args.app or os.getenv("SHATTERED_APP", "app.py") + module_name = os.path.splitext(module_file)[0] + __import__(module_name) + module = sys.modules[module_name] + matches = [v for v in module.__dict__.values() if isinstance(v, Shattered)] + if len(matches) == 1: + return matches[0] + else: + raise EnvironmentError("Detected multiple Shattered applications.") + + +def run_app(args): + app = load_app(args) + app.run() + + +def print_config(args): + app = load_app(args) + for k, v in app.config.items(): + print(f"{k}: {v}") + + +def parse_arguments(): + parser = argparse.ArgumentParser(prog="shattered") + parser.add_argument("-a", "--app", help="shattered app") + subparsers = parser.add_subparsers( + title="subcommands", description="valid subcommands", help="additional help" + ) + parser_run = subparsers.add_parser("run", help="run shattered app") + parser_run.set_defaults(func=run_app) + parser_config = subparsers.add_parser("config", help="show shattered app config") + parser_config.set_defaults(func=print_config) + return parser.parse_args() + + +def main(): + args = parse_arguments() + args.func(args) diff --git a/src/shattered/shattered.py b/src/shattered/shattered.py index f9fbfbd..c5884b3 100644 --- a/src/shattered/shattered.py +++ b/src/shattered/shattered.py @@ -27,16 +27,25 @@ def on_message(self, headers, body): class Shattered: def __init__(self, **config): self.config = config - self.host = self.config.get("host", "localhost") - self.port = self.config.get("port", 61613) - self.username = self.config.get("username", "guest") - self.password = self.config.get("password", "guest") - self.vhost = self.config.get("vhost", "/") + self._set_config_defaults() + + self.host = self.config["host"] + self.port = self.config["port"] + self.username = self.config["username"] + self.password = self.config["password"] + self.vhost = self.config["vhost"] self.subscriptions = {} self.conn = None self.listener = ShatteredListener(self) + def _set_config_defaults(self): + self.config.setdefault("host", "localhost") + self.config.setdefault("port", 61613) + self.config.setdefault("username", "guest") + self.config.setdefault("password", "guest") + self.config.setdefault("vhost", "/") + def add_subscription(self, destination, callback): if destination not in self.subscriptions: self.subscriptions[destination] = []