diff --git a/setup.py b/setup.py index a60cd9f..2c552ff 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ tests_require = [ "transaction", + "pyramid", "nose", "coverage", ] diff --git a/sqlahelper.py b/sqlahelper.py index ebf18e3..c0afe70 100644 --- a/sqlahelper.py +++ b/sqlahelper.py @@ -1,3 +1,4 @@ +import re import sqlalchemy as sa import sqlalchemy.ext.declarative as declarative import sqlalchemy.orm as orm @@ -6,6 +7,60 @@ # Import only version 1 API with "import *" __all__ = ["add_engine", "get_base", "get_session", "get_engine"] +# pyramid configuration + +truthy = frozenset(['true', 'yes', 'on', 'y', 't', '1']) +falsy = frozenset(['false', 'no', 'off', 'n', 'f', '0']) + +def asbool(obj): + if isinstance(obj, basestring): + obj = obj.strip().lower() + if obj in truthy: + return True + elif obj in falsy: + return False + else: + raise ValueError("String is not true/false: %r" % obj) + return bool(obj) + +engine_url_pattern = re.compile(r'sqlahelper\.(?P\w+)\.url') +engine_echo_pattern = re.compile(r'sqlahelper\.(?P\w+)\.echo') + +def includeme(config): + """ set up engines from config. + + usege in :term:`Pyramid`:: + + config.include('sqlahelper') + + config.ini:: + + sqlalchemy.url = sqlite:///%(here)s/myapp.db + sqlahelper.otherengine.url = sqlite:///%(here)s/myapp_other.db + + ``sqlalchemy.url`` is set to default engine. + ``sqlahelper.otherengine.url`` is set to engine named "otherengine". + + """ + + settings = config.registry.settings + if 'sqlalchemy.url' in settings: + engine = sa.engine_from_config(settings) + set_default_engine(engine) + + for k, v in settings.items(): + url_match = engine_url_pattern.match(k) + if url_match: + engine_name = url_match.groupdict()['engine_name'] + engine = sa.create_engine(v) + setattr(engines, engine_name, engine) + echo_match = engine_echo_pattern.match(k) + if echo_match: + engine_name = echo_match.groupdict()['engine_name'] + if not hasattr(engines, engine_name): + continue + setattr(getattr(engines, engine_name), "echo", asbool(v)) + # VERSION 2 API class AttributeContainer(object): diff --git a/tests.py b/tests.py index 7ef89df..d3343bc 100644 --- a/tests.py +++ b/tests.py @@ -166,6 +166,48 @@ def test1(self): self.assertNotEqual(base2, base) self.assertEqual(base2, my_base) +class IncludeMeTests(unittest.TestCase): + def setUp(self): + from pyramid import testing + self.config = testing.setUp() + + def tearDown(self): + from pyramid import testing + testing.tearDown() + + def _callFUT(self, *args, **kwargs): + from sqlahelper import includeme + return includeme(*args, **kwargs) + + def test_default_engine(self): + engine_settings = { + 'sqlalchemy.url': 'sqlite:///', + 'sqlalchemy.echo': 'true', + } + self.config.registry.settings.update(engine_settings) + + self._callFUT(self.config) + + import sqlahelper + + self.assertIsNotNone(sqlahelper.engines.default) + self.assertEqual(str(sqlahelper.engines.default.url), 'sqlite:///') + self.assertTrue(sqlahelper.engines.default.echo) + + def test_another_engines(self): + engine_settings = { + 'sqlahelper.other.url': 'sqlite:///', + 'sqlahelper.other.echo': 'true', + } + self.config.registry.settings.update(engine_settings) + + self._callFUT(self.config) + + import sqlahelper + + self.assertIsNotNone(sqlahelper.engines.other) + self.assertEqual(str(sqlahelper.engines.other.url), 'sqlite:///') + self.assertTrue(sqlahelper.engines.other.echo) if __name__ == "__main__": unittest.main()