From ccbf60c686815f00dee3094c4beed8c3e65e5bf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89lie=20Bouttier?= Date: Thu, 24 Mar 2022 17:52:59 +0100 Subject: [PATCH] apply black --- setup.py | 60 +- src/utils_flask_sqla/commands.py | 160 +-- src/utils_flask_sqla/errors.py | 17 +- src/utils_flask_sqla/generic.py | 106 +- .../versions/3842a6d800a0_sql_utils.py | 16 +- src/utils_flask_sqla/response.py | 13 +- src/utils_flask_sqla/schema.py | 12 +- src/utils_flask_sqla/serializers.py | 252 ++--- src/utils_flask_sqla/tests/test_schema.py | 46 +- .../tests/test_serializers.py | 950 +++++++++++------- 10 files changed, 910 insertions(+), 722 deletions(-) diff --git a/setup.py b/setup.py index 83bedf1..5dd91c6 100644 --- a/setup.py +++ b/setup.py @@ -3,51 +3,53 @@ root_dir = Path(__file__).absolute().parent -with (root_dir / 'VERSION').open() as f: +with (root_dir / "VERSION").open() as f: version = f.read() -with (root_dir / 'README.md').open() as f: +with (root_dir / "README.md").open() as f: long_description = f.read() -with (root_dir / 'requirements.in').open() as f: +with (root_dir / "requirements.in").open() as f: requirements = f.read().splitlines() setuptools.setup( - name='utils-flask-sqlalchemy', + name="utils-flask-sqlalchemy", version=version, description="Python lib of tools for Flask and SQLAlchemy", long_description=long_description, - long_description_content_type='text/markdown', - maintainer='Parcs nationaux des Écrins et des Cévennes', - maintainer_email='geonature@ecrins-parcnational.fr', - url='https://github.com/PnX-SI/Utils-Flask-SQLAlchemy', - packages=setuptools.find_packages('src'), - package_dir={'': 'src'}, + long_description_content_type="text/markdown", + maintainer="Parcs nationaux des Écrins et des Cévennes", + maintainer_email="geonature@ecrins-parcnational.fr", + url="https://github.com/PnX-SI/Utils-Flask-SQLAlchemy", + packages=setuptools.find_packages("src"), + package_dir={"": "src"}, install_requires=requirements, extras_require={ - 'tests': [ - 'pytest', - 'geoalchemy2', - 'shapely', - 'jsonschema', - 'flask-marshmallow', - 'marshmallow-sqlalchemy', + "tests": [ + "pytest", + "geoalchemy2", + "shapely", + "jsonschema", + "flask-marshmallow", + "marshmallow-sqlalchemy", ], }, entry_points={ - 'alembic': [ - 'migrations = utils_flask_sqla.migrations:versions', + "alembic": [ + "migrations = utils_flask_sqla.migrations:versions", ], - 'pytest11': [ - 'sqla = utils_flask_sqla.tests.plugin', + "pytest11": [ + "sqla = utils_flask_sqla.tests.plugin", ], - 'flask.commands': [ - 'db = utils_flask_sqla.commands:db_cli', + "flask.commands": [ + "db = utils_flask_sqla.commands:db_cli", ], }, - classifiers=['Development Status :: 1 - Planning', - 'Intended Audience :: Developers', - 'Natural Language :: English', - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: GNU Affero General Public License v3', - 'Operating System :: OS Independent'], + classifiers=[ + "Development Status :: 1 - Planning", + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: GNU Affero General Public License v3", + "Operating System :: OS Independent", + ], ) diff --git a/src/utils_flask_sqla/commands.py b/src/utils_flask_sqla/commands.py index 7e828ac..4ae0f2b 100644 --- a/src/utils_flask_sqla/commands.py +++ b/src/utils_flask_sqla/commands.py @@ -13,54 +13,59 @@ def box_drowing(up, down, left, right): - if not up and not down and not left and not right: - return '─' - elif up and not down and not left and not right: - return '┸' - elif not up and down and not left and not right: - return '┰' - elif up and down and not left and not right: - return '┃' - elif up and not down and left and not right: - return '┛' - elif up and not down and not left and right: - return '┗' - elif not up and not down and left and right: - return '━' - elif not up and down and left and not right: - return '┓' - elif not up and down and not left and right: - return '┏' - elif up and down and not left and right: - return '┣' - elif up and down and left and not right: - return '┫' - elif up and not down and left and right: - return '┻' - elif not up and down and left and right: - return '┳' - elif up and down and left and right: - return '╋' + if not up and not down and not left and not right: + return "─" + elif up and not down and not left and not right: + return "┸" + elif not up and down and not left and not right: + return "┰" + elif up and down and not left and not right: + return "┃" + elif up and not down and left and not right: + return "┛" + elif up and not down and not left and right: + return "┗" + elif not up and not down and left and right: + return "━" + elif not up and down and left and not right: + return "┓" + elif not up and down and not left and right: + return "┏" + elif up and down and not left and right: + return "┣" + elif up and down and left and not right: + return "┫" + elif up and not down and left and right: + return "┻" + elif not up and down and left and right: + return "┳" + elif up and down and left and right: + return "╋" else: raise Exception("Unexpected box drowing symbol") @db_cli.command() -@click.option('-d', '--directory', default=None, - help=('Migration script directory (default is "migrations")')) -@click.option('--sql', is_flag=True, - help=('Don\'t emit SQL to database - dump to standard output ' - 'instead')) -@click.option('--tag', default=None, - help=('Arbitrary "tag" name - can be used by custom env.py ' - 'scripts')) -@click.option('-x', '--x-arg', multiple=True, - help='Additional arguments consumed by custom env.py scripts') +@click.option( + "-d", + "--directory", + default=None, + help=('Migration script directory (default is "migrations")'), +) +@click.option( + "--sql", is_flag=True, help=("Don't emit SQL to database - dump to standard output " "instead") +) +@click.option( + "--tag", default=None, help=('Arbitrary "tag" name - can be used by custom env.py ' "scripts") +) +@click.option( + "-x", "--x-arg", multiple=True, help="Additional arguments consumed by custom env.py scripts" +) @with_appcontext def autoupgrade(directory, sql, tag, x_arg): """Upgrade all branches to head.""" - db = current_app.extensions['sqlalchemy'].db - migrate = current_app.extensions['migrate'].migrate + db = current_app.extensions["sqlalchemy"].db + migrate = current_app.extensions["migrate"].migrate config = migrate.get_config(directory, x_arg) script = ScriptDirectory.from_config(config) heads = set(script.get_heads()) @@ -69,36 +74,41 @@ def autoupgrade(directory, sql, tag, x_arg): # get_current_heads does not return implicit revision through dependecies, get_all_current does current_heads = set(map(lambda rev: rev.revision, script.get_all_current(current_heads))) for head in current_heads - heads: - revision = head + '@head' + revision = head + "@head" flask_migrate.upgrade(directory, revision, sql, tag, x_arg) @db_cli.command() -@click.option('-d', '--directory', default=None, - help=('Migration script directory (default is "migrations")')) -@click.option('-x', '--x-arg', multiple=True, - help='Additional arguments consumed by custom env.py scripts') -@click.argument('branches', nargs=-1) +@click.option( + "-d", + "--directory", + default=None, + help=('Migration script directory (default is "migrations")'), +) +@click.option( + "-x", "--x-arg", multiple=True, help="Additional arguments consumed by custom env.py scripts" +) +@click.argument("branches", nargs=-1) @with_appcontext def status(directory, x_arg, branches): """Show all revisions sorted by branches.""" - db = current_app.extensions['sqlalchemy'].db - migrate = current_app.extensions['migrate'].migrate + db = current_app.extensions["sqlalchemy"].db + migrate = current_app.extensions["migrate"].migrate config = migrate.get_config(directory, x_arg) script = ScriptDirectory.from_config(config) migration_context = MigrationContext.configure(db.session.connection()) current_heads = migration_context.get_current_heads() - applied_rev = set(script.iterate_revisions(current_heads, 'base')) + applied_rev = set(script.iterate_revisions(current_heads, "base")) - bases = [ script.get_revision(base) for base in script.get_bases() ] - heads = [ script.get_revision(head) for head in script.get_heads() ] + bases = [script.get_revision(base) for base in script.get_bases()] + heads = [script.get_revision(head) for head in script.get_heads()] outdated = False for branch_base in sorted(bases, key=lambda rev: next(iter(rev.branch_labels))): output = StringIO() - branch, = branch_base.branch_labels + (branch,) = branch_base.branch_labels if branches and branch not in branches: continue levels = defaultdict(set) @@ -116,13 +126,11 @@ def status(directory, x_arg, branches): down_revisions = [rev.down_revision] else: down_revisions = [] - down_revisions = [ script.get_revision(r) for r in down_revisions ] + down_revisions = [script.get_revision(r) for r in down_revisions] - next_revisions = [ script.get_revision(r) for r in rev.nextrev ] + next_revisions = [script.get_revision(r) for r in rev.nextrev] - if (rev.is_merge_point - and (not seen.issuperset(down_revisions) - or rev in todo)): + if rev.is_merge_point and (not seen.issuperset(down_revisions) or rev in todo): continue seen.add(rev) @@ -139,38 +147,44 @@ def status(directory, x_arg, branches): all_levels = list(chain(down_levels, next_levels)) min_level = min(all_levels, default=0) max_level = max(all_levels, default=0) - symbol = '' + symbol = "" for i in range(max_level + 1): if i < min_level: - symbol += ' ' + symbol += " " else: symbol += box_drowing( - up = i in down_levels, - down = i in next_levels, - left = i > min_level, - right = i < max_level, + up=i in down_levels, + down=i in next_levels, + left=i > min_level, + right=i < max_level, ) - check = 'x' if rev in applied_rev else ' ' + check = "x" if rev in applied_rev else " " if branch_base in applied_rev and rev in applied_rev: - fg = 'white' + fg = "white" elif branch_base in applied_rev: outdated = True branch_outdated = True - fg = 'red' + fg = "red" else: fg = None - print(click.style(f" [{check}] {symbol} {rev.revision} {rev.doc}", fg=fg), file=output) + print( + click.style(f" [{check}] {symbol} {rev.revision} {rev.doc}", fg=fg), file=output + ) if branch_base in applied_rev: - fg = 'white' - mark = ' ' - mark += click.style('×', fg='red') if branch_outdated else click.style('✓', fg='green') + fg = "white" + mark = " " + mark += click.style("×", fg="red") if branch_outdated else click.style("✓", fg="green") else: fg = None - mark = '' - click.echo(click.style(f"[{branch}", bold=True, fg=fg) + mark + click.style("]", bold=True, fg=fg)) + mark = "" + click.echo( + click.style(f"[{branch}", bold=True, fg=fg) + mark + click.style("]", bold=True, fg=fg) + ) click.echo(output.getvalue(), nl=False) if outdated: - click.secho("Some branches are outdated, you can upgrade with 'autoupgrade' sub-command.", fg="red") + click.secho( + "Some branches are outdated, you can upgrade with 'autoupgrade' sub-command.", fg="red" + ) diff --git a/src/utils_flask_sqla/errors.py b/src/utils_flask_sqla/errors.py index 997e083..eefab4d 100644 --- a/src/utils_flask_sqla/errors.py +++ b/src/utils_flask_sqla/errors.py @@ -4,22 +4,15 @@ def __init__(self, message, status_code=500): self.message = message self.status_code = status_code raised_error = self.__class__.__name__ - log_message = "Raise: {}, {}".format( - raised_error, - message - ) + log_message = "Raise: {}, {}".format(raised_error, message) def to_dict(self): return { - 'message': self.message, - 'status_code': self.status_code, - 'raisedError': self.__class__.__name__ + "message": self.message, + "status_code": self.status_code, + "raisedError": self.__class__.__name__, } def __str__(self): message = "Error {}, Message: {}, raised error: {}" - return message.format( - self.status_code, - self.message, - self.__class__.__name__ - ) + return message.format(self.status_code, self.message, self.__class__.__name__) diff --git a/src/utils_flask_sqla/generic.py b/src/utils_flask_sqla/generic.py index e01cc9b..1285b4b 100644 --- a/src/utils_flask_sqla/generic.py +++ b/src/utils_flask_sqla/generic.py @@ -11,9 +11,9 @@ def testDataType(value, sqlType, paramName): """ - Test the type of a filter - #TODO: antipatern: should raise something which can be exect by the function which use it - # and not return the error + Test the type of a filter + #TODO: antipatern: should raise something which can be exect by the function which use it + # and not return the error """ if sqlType == Integer or isinstance(sqlType, (Integer)): try: @@ -35,12 +35,12 @@ def testDataType(value, sqlType, paramName): def test_type_and_generate_query(param_name, value, model, q): """ - Generate a query with the filter given, checking the params is the good type of the columns, and formmatting it - Params: - - param_name (str): the name of the column - - value (any): the value of the filter - - model (SQLA model) - - q (SQLA Query) + Generate a query with the filter given, checking the params is the good type of the columns, and formmatting it + Params: + - param_name (str): the name of the column + - value (any): the value of the filter + - model (SQLA model) + - q (SQLA Query) """ # check the attribut exist in the model try: @@ -52,22 +52,17 @@ def test_type_and_generate_query(param_name, value, model, q): try: return q.filter(col == int(value)) except Exception as e: - raise UtilsSqlaError( - "{0} must be an integer".format(param_name)) + raise UtilsSqlaError("{0} must be an integer".format(param_name)) if sql_type == Numeric or isinstance(sql_type, (Numeric)): try: return q.filter(col == float(value)) except Exception as e: - raise UtilsSqlaError( - "{0} must be an float (decimal separator .)".format(param_name) - ) + raise UtilsSqlaError("{0} must be an float (decimal separator .)".format(param_name)) if sql_type == DateTime or isinstance(sql_type, (Date, DateTime)): try: return q.filter(col == parser.parse(value)) except Exception as e: - raise UtilsSqlaError( - "{0} must be an date (yyyy-mm-dd)".format(param_name) - ) + raise UtilsSqlaError("{0} must be an date (yyyy-mm-dd)".format(param_name)) if sql_type == Boolean or isinstance(sql_type, Boolean): try: @@ -93,34 +88,33 @@ def test_type_and_generate_query(param_name, value, model, q): class GenericTable: """ - Classe permettant de créer à la volée un mapping - d'une vue avec la base de données par rétroingénierie + Classe permettant de créer à la volée un mapping + d'une vue avec la base de données par rétroingénierie """ def __init__(self, tableName, schemaName, engine): - ''' - params: - - tableName - - schemaName - - engine : sqlalchemy instance engine - for exemple : DB.engine if DB = Sqlalchemy() - ''' + """ + params: + - tableName + - schemaName + - engine : sqlalchemy instance engine + for exemple : DB.engine if DB = Sqlalchemy() + """ meta = MetaData(schema=schemaName, bind=engine) meta.reflect(views=True) try: self.tableDef = meta.tables["{}.{}".format(schemaName, tableName)] except KeyError: - raise KeyError("table {}.{} doesn't exists".format( - schemaName, tableName)) + raise KeyError("table {}.{} doesn't exists".format(schemaName, tableName)) # Mise en place d'un mapping des colonnes en vue d'une sérialisation self.serialize_columns, self.db_cols = self.get_serialized_columns() def get_serialized_columns(self, serializers=SERIALIZERS): """ - Return a tuple of serialize_columns, and db_cols - from the generic table + Return a tuple of serialize_columns, and db_cols + from the generic table """ regular_serialize = [] db_cols = [] @@ -128,9 +122,7 @@ def get_serialized_columns(self, serializers=SERIALIZERS): if not db_col.type.__class__.__name__ == "Geometry": serialize_attr = ( name, - serializers.get( - db_col.type.__class__.__name__.lower(), lambda x: x - ), + serializers.get(db_col.type.__class__.__name__.lower(), lambda x: x), ) regular_serialize.append(serialize_attr) @@ -140,31 +132,32 @@ def get_serialized_columns(self, serializers=SERIALIZERS): def as_dict(self, data, columns=[], fields=[]): fields = list(chain(fields, columns)) if columns: - warn("'columns' argument is deprecated. Please add columns to serialize " - "directly in 'fields' argument.", DeprecationWarning) + warn( + "'columns' argument is deprecated. Please add columns to serialize " + "directly in 'fields' argument.", + DeprecationWarning, + ) if fields: - fprops = list( - filter(lambda d: d[0] in fields, self.serialize_columns)) + fprops = list(filter(lambda d: d[0] in fields, self.serialize_columns)) else: fprops = self.serialize_columns return {item: _serializer(getattr(data, item)) for item, _serializer in fprops} - class GenericQuery: """ - Classe permettant de manipuler des objets GenericTable - - params: - - DB: sqlalchemy instantce (DB if DB = Sqlalchemy()) - - tableName - - schemaName - - filters: array of filter of the query - - engine : sqlalchemy instance engine - for exemple : DB.engine if DB = Sqlalchemy() - - limit - - offset + Classe permettant de manipuler des objets GenericTable + + params: + - DB: sqlalchemy instantce (DB if DB = Sqlalchemy()) + - tableName + - schemaName + - filters: array of filter of the query + - engine : sqlalchemy instance engine + for exemple : DB.engine if DB = Sqlalchemy() + - limit + - offset """ def __init__( @@ -186,7 +179,7 @@ def __init__( def build_query_filters(self, query, parameters): """ - Construction des filtres + Construction des filtres """ for f in parameters: query = self.build_query_filter(query, f, parameters.get(f)) @@ -195,8 +188,7 @@ def build_query_filters(self, query, parameters): def build_query_filter(self, query, param_name, param_value): if param_name in self.view.tableDef.columns.keys(): - query = query.filter( - self.view.tableDef.columns[param_name] == param_value) + query = query.filter(self.view.tableDef.columns[param_name] == param_value) if param_name.startswith("ilike_"): col = self.view.tableDef.columns[param_name[6:]] @@ -235,7 +227,7 @@ def build_query_order(self, query, parameters): # et prend la forme suivante : nom_colonne[:ASC|DESC] if parameters.get("orderby", None).replace(" ", ""): order_by = parameters.get("orderby") - col, *sort = order_by.split(':') + col, *sort = order_by.split(":") if col in self.view.tableDef.columns.keys(): ordel_col = getattr(self.view.tableDef.columns, col) if (sort[0:1] or ["ASC"])[0].lower() == "desc": @@ -245,7 +237,7 @@ def build_query_order(self, query, parameters): def query(self): """ - Lance la requete et retourne l'objet sqlalchemy + Lance la requete et retourne l'objet sqlalchemy """ q = self.DB.session.query(self.view.tableDef) nb_result_without_filter = q.count() @@ -266,8 +258,8 @@ def query(self): def return_query(self): """ - Lance la requete (execute self.query()) - et retourne les résutats dans un format standard + Lance la requete (execute self.query()) + et retourne les résutats dans un format standard """ @@ -318,6 +310,6 @@ def serializeQueryTest(data, column_def): elif isinstance(c["type"], Numeric): inter[c["name"]] = float(getattr(row, c["name"])) # elif not isinstance(c["type"], Geometry): - # inter[c["name"]] = getattr(row, c["name"]) + # inter[c["name"]] = getattr(row, c["name"]) rows.append(inter) return rows diff --git a/src/utils_flask_sqla/migrations/versions/3842a6d800a0_sql_utils.py b/src/utils_flask_sqla/migrations/versions/3842a6d800a0_sql_utils.py index e067c6f..55f0cac 100644 --- a/src/utils_flask_sqla/migrations/versions/3842a6d800a0_sql_utils.py +++ b/src/utils_flask_sqla/migrations/versions/3842a6d800a0_sql_utils.py @@ -12,14 +12,15 @@ # revision identifiers, used by Alembic. -revision = '3842a6d800a0' +revision = "3842a6d800a0" down_revision = None -branch_labels = ('sql_utils',) +branch_labels = ("sql_utils",) depends_on = None def upgrade(): - op.execute(''' + op.execute( + """ CREATE OR REPLACE FUNCTION public.fct_trg_meta_dates_change() RETURNS trigger AS $BODY$ @@ -37,10 +38,13 @@ def upgrade(): $BODY$ LANGUAGE plpgsql VOLATILE COST 100; -''') +""" + ) def downgrade(): - op.execute(''' + op.execute( + """ DROP FUNCTION public.fct_trg_meta_dates_change(); - ''') + """ + ) diff --git a/src/utils_flask_sqla/response.py b/src/utils_flask_sqla/response.py index a06cb26..7c6c9d5 100644 --- a/src/utils_flask_sqla/response.py +++ b/src/utils_flask_sqla/response.py @@ -7,6 +7,7 @@ from flask.json import JSONEncoder from werkzeug.datastructures import Headers + def json_resp_accept(accepted_list=[]): def json_resp(fn): """ @@ -31,8 +32,6 @@ def _json_resp(*args, **kwargs): json_resp_accept_empty_list = json_resp_accept([[]]) - - def to_json_resp( res, status=200, @@ -43,7 +42,7 @@ def to_json_resp( accepted_list=[], ): if res is None: - return ('', 204) + return ("", 204) headers = None if as_file: @@ -56,9 +55,7 @@ def to_json_resp( ) # use Flask JSONEncoder for raw sql query return Response( - json.dumps( - res, ensure_ascii=False, indent=indent, cls=JSONEncoder - ), + json.dumps(res, ensure_ascii=False, indent=indent, cls=JSONEncoder), status=status, mimetype="application/json", headers=headers, @@ -83,9 +80,7 @@ def to_csv_resp(filename, data, columns, separator=";"): headers = Headers() headers.add("Content-Type", "text/plain") - headers.add( - "Content-Disposition", "attachment", filename="export_%s.csv" % filename - ) + headers.add("Content-Disposition", "attachment", filename="export_%s.csv" % filename) out = generate_csv_content(columns, data, separator) return Response(out, headers=headers) diff --git a/src/utils_flask_sqla/schema.py b/src/utils_flask_sqla/schema.py index 768920f..c833ff6 100644 --- a/src/utils_flask_sqla/schema.py +++ b/src/utils_flask_sqla/schema.py @@ -11,16 +11,16 @@ def __init__(self, *args, **kwargs): else: natural_fields.add(name) - only = kwargs.pop('only', None) + only = kwargs.pop("only", None) only = set(only) if only is not None else set() - firstlevel_only = { field.split('.', 1)[0] for field in only } - exclude = kwargs.pop('exclude', None) + firstlevel_only = {field.split(".", 1)[0] for field in only} + exclude = kwargs.pop("exclude", None) exclude = set(exclude) if exclude is not None else set() - exclude |= (nested_fields - firstlevel_only) + exclude |= nested_fields - firstlevel_only if not firstlevel_only - nested_fields: only |= natural_fields if only: - kwargs['only'] = only + kwargs["only"] = only if exclude: - kwargs['exclude'] = exclude + kwargs["exclude"] = exclude super().__init__(*args, **kwargs) diff --git a/src/utils_flask_sqla/serializers.py b/src/utils_flask_sqla/serializers.py index a5976a2..46090be 100644 --- a/src/utils_flask_sqla/serializers.py +++ b/src/utils_flask_sqla/serializers.py @@ -53,22 +53,21 @@ def get_serializable_decorator(fields=[], exclude=[], stringify=True): default_fields = fields default_exclude = exclude default_stringify = stringify - firstlevel_default_fields = { field.split('.')[0] - for field in default_fields } + firstlevel_default_fields = {field.split(".")[0] for field in default_fields} def _serializable(cls): """ - Décorateur de classe pour les DB.Models - Permet de rajouter la fonction as_dict - qui est basée sur le mapping SQLAlchemy + Décorateur de classe pour les DB.Models + Permet de rajouter la fonction as_dict + qui est basée sur le mapping SQLAlchemy """ mapper = inspect(cls) def get_cls_db_columns(): """ - Liste des propriétés sérialisables de la classe - associées à leur sérializer en fonction de leur type + Liste des propriétés sérialisables de la classe + associées à leur sérializer en fonction de leur type """ cls_db_columns = [] for prop in cls.__mapper__.column_attrs: @@ -78,18 +77,17 @@ def get_cls_db_columns(): db_col = prop.columns[-1] # HACK # -> Récupération du nom de l'attribut sans la classe - name = str(prop).split('.', 1)[1] - if db_col.type.__class__.__name__ == 'Geometry': + name = str(prop).split(".", 1)[1] + if db_col.type.__class__.__name__ == "Geometry": continue if name in exclude: continue - cls_db_columns.append(( - name, - SERIALIZERS.get( - db_col.type.__class__.__name__.lower(), - lambda x: x + cls_db_columns.append( + ( + name, + SERIALIZERS.get(db_col.type.__class__.__name__.lower(), lambda x: x), ) - )) + ) """ Liste des propriétés synonymes sérialisables de la classe @@ -98,32 +96,25 @@ def get_cls_db_columns(): for syn in cls.__mapper__.synonyms: col = cls.__mapper__.c[syn.name] # if column type is geometry pass - if col.type.__class__.__name__ == 'Geometry': + if col.type.__class__.__name__ == "Geometry": pass # else add synonyms in columns properties - cls_db_columns.append(( - syn.key, - SERIALIZERS.get( - col.type.__class__.__name__.lower(), - lambda x: x - ) - )) + cls_db_columns.append( + (syn.key, SERIALIZERS.get(col.type.__class__.__name__.lower(), lambda x: x)) + ) return cls_db_columns def get_cls_db_relationships(): """ - Liste des propriétés de type relationship - uselist permet de savoir si c'est une collection de sous objet - sa valeur est déduite du type de relation - (OneToMany, ManyToOne ou ManyToMany) + Liste des propriétés de type relationship + uselist permet de savoir si c'est une collection de sous objet + sa valeur est déduite du type de relation + (OneToMany, ManyToOne ou ManyToMany) """ return [ - ( - db_rel.key, - db_rel.uselist, - getattr(cls, db_rel.key).mapper.class_ - ) for db_rel in cls.__mapper__.relationships + (db_rel.key, db_rel.uselist, getattr(cls, db_rel.key).mapper.class_) + for db_rel in cls.__mapper__.relationships ] @lru_cache(maxsize=None) @@ -136,10 +127,10 @@ def get_columns_and_relationships(fields=None, exclude=None): additional_fields = set() relationship_fields = set() for field in fields: - if field.split('.')[0] in mapper.relationships: + if field.split(".")[0] in mapper.relationships: relationship_fields.add(field) - elif field.startswith('+'): - field = field.lstrip('+') + elif field.startswith("+"): + field = field.lstrip("+") additional_fields.add(field) else: base_fields.add(field) @@ -158,48 +149,63 @@ def get_columns_and_relationships(fields=None, exclude=None): exclude = _default_exclude # take 'a' instead of 'a.b' - firstlevel_fields = [ rel.split('.')[0] for rel in fields ] - - properties = { key: None - for key in dir(cls) - if isinstance(getattr_static(cls, key), property) } - hybrid_properties = { key: attr - for key, attr in mapper.all_orm_descriptors.items() - if attr.extension_type == HYBRID_PROPERTY } - for field in set([ f for f in fields if '.' not in f ]) \ - - set(mapper.column_attrs.keys()) \ - - set(properties.keys()) \ - - set(hybrid_properties.keys()) \ - - set(mapper.relationships.keys()): + firstlevel_fields = [rel.split(".")[0] for rel in fields] + + properties = { + key: None for key in dir(cls) if isinstance(getattr_static(cls, key), property) + } + hybrid_properties = { + key: attr + for key, attr in mapper.all_orm_descriptors.items() + if attr.extension_type == HYBRID_PROPERTY + } + for field in ( + set([f for f in fields if "." not in f]) + - set(mapper.column_attrs.keys()) + - set(properties.keys()) + - set(hybrid_properties.keys()) + - set(mapper.relationships.keys()) + ): raise Exception(f"Field '{field}' does not exist on {cls}.") - for field in set([ f.split('.')[0] for f in fields if '.' in f ]) \ - - set(mapper.relationships.keys()): + for field in set([f.split(".")[0] for f in fields if "." in f]) - set( + mapper.relationships.keys() + ): raise Exception(f"Relationship '{field}' does not exist on {cls}.") - _columns = { key: col - for key, col in ChainMap(dict(mapper.column_attrs), properties, hybrid_properties).items() - if key in fields } - _relationships = { key: rel - for key, rel in mapper.relationships.items() - if key in firstlevel_fields } + _columns = { + key: col + for key, col in ChainMap( + dict(mapper.column_attrs), properties, hybrid_properties + ).items() + if key in fields + } + _relationships = { + key: rel for key, rel in mapper.relationships.items() if key in firstlevel_fields + } if not _columns: _columns = ChainMap(dict(mapper.column_attrs), properties, hybrid_properties) if exclude: - _columns = { key: col - for key, col in _columns.items() - if key not in exclude } - _relationships = { key: rel - for key, rel in _relationships.items() - if key not in exclude } + _columns = {key: col for key, col in _columns.items() if key not in exclude} + _relationships = { + key: rel for key, rel in _relationships.items() if key not in exclude + } - _columns = { key: (col, get_serializer(col)) - for key, col in _columns.items() } + _columns = {key: (col, get_serializer(col)) for key, col in _columns.items()} return fields, exclude, _columns, _relationships - def serializefn(self, recursif=False, columns=[], relationships=[], - fields=None, exclude=None, stringify=None, unloaded=None, - depth=None, _excluded_mappers=[]): + def serializefn( + self, + recursif=False, + columns=[], + relationships=[], + fields=None, + exclude=None, + stringify=None, + unloaded=None, + depth=None, + _excluded_mappers=[], + ): """ Méthode qui renvoie les données de l'objet sous la forme d'un dict @@ -243,14 +249,20 @@ def serializefn(self, recursif=False, columns=[], relationships=[], stringify = default_stringify if columns: - warn("'columns' argument is deprecated. Please add columns to serialize " - "directly in 'fields' argument.", DeprecationWarning) + warn( + "'columns' argument is deprecated. Please add columns to serialize " + "directly in 'fields' argument.", + DeprecationWarning, + ) if fields is None: fields = [] fields = chain(fields, columns) if relationships: - warn("'relationships' argument is deprecated. Please add relationships to serialize " - "directly in 'fields' argument.", DeprecationWarning) + warn( + "'relationships' argument is deprecated. Please add relationships to serialize " + "directly in 'fields' argument.", + DeprecationWarning, + ) if fields is None: fields = [] fields = chain(fields, relationships) @@ -258,15 +270,23 @@ def serializefn(self, recursif=False, columns=[], relationships=[], if depth: recursif = True if recursif: - warn("'recursif' argument is deprecated. Please add relationships to serialize " - "directly in 'fields' argument.", DeprecationWarning) - _excluded_mappers = _excluded_mappers + [ mapper ] + warn( + "'recursif' argument is deprecated. Please add relationships to serialize " + "directly in 'fields' argument.", + DeprecationWarning, + ) + _excluded_mappers = _excluded_mappers + [mapper] if depth is None or depth > 0: if fields is None: fields = [] - fields = chain(fields, [ rel.key for rel in mapper.relationships - if rel.key not in fields - and rel.mapper not in _excluded_mappers ]) + fields = chain( + fields, + [ + rel.key + for rel in mapper.relationships + if rel.key not in fields and rel.mapper not in _excluded_mappers + ], + ) if depth: depth -= 1 @@ -275,13 +295,15 @@ def serializefn(self, recursif=False, columns=[], relationships=[], if exclude is not None: exclude = frozenset(exclude) - fields, exclude, _columns, _relationships = get_columns_and_relationships(fields, exclude) + fields, exclude, _columns, _relationships = get_columns_and_relationships( + fields, exclude + ) serialize_kwargs = { - 'recursif': recursif, - 'depth': depth, - 'unloaded': unloaded, - '_excluded_mappers': _excluded_mappers, + "recursif": recursif, + "depth": depth, + "unloaded": unloaded, + "_excluded_mappers": _excluded_mappers, } data = {} @@ -295,21 +317,21 @@ def serializefn(self, recursif=False, columns=[], relationships=[], m = inspect(self) if key in m.unloaded: err = f"Relationship '{key}' on '{self}' is not loaded" - if unloaded == 'raise': + if unloaded == "raise": raise Exception(err) - elif unloaded == 'warn': + elif unloaded == "warn": warn(err) kwargs = serialize_kwargs.copy() - _fields = [ field.split('.', 1)[1] - for field in fields - if field.startswith(f'{key}.') ] - kwargs['fields'] = _fields or None - _exclude = [ field.split('.', 1)[1] - for field in exclude - if field.startswith(f'{key}.') ] - kwargs['exclude'] = _exclude or None + _fields = [ + field.split(".", 1)[1] for field in fields if field.startswith(f"{key}.") + ] + kwargs["fields"] = _fields or None + _exclude = [ + field.split(".", 1)[1] for field in exclude if field.startswith(f"{key}.") + ] + kwargs["exclude"] = _exclude or None if rel.uselist: - data[key] = [ o.as_dict(**kwargs) for o in getattr(self, key) ] + data[key] = [o.as_dict(**kwargs) for o in getattr(self, key)] else: rel_object = getattr(self, rel.key) if rel_object: @@ -317,11 +339,11 @@ def serializefn(self, recursif=False, columns=[], relationships=[], else: # relationship may be null data[rel.key] = None return data - serializefn.__original_decorator = True + serializefn.__original_decorator = True def populatefn(self, dict_in, recursif=False): - ''' + """ Méthode qui initie les valeurs de l'objet à partir d'un dictionnaire Parameters @@ -329,7 +351,7 @@ def populatefn(self, dict_in, recursif=False): dict_in : dictionnaire contenant les valeurs à passer à l'objet recursif: si on renseigne les relationships - ''' + """ cls_db_columns_key = list(map(lambda x: x[0], get_cls_db_columns())) @@ -374,22 +396,12 @@ def populatefn(self, dict_in, recursif=False): values_inter.append(data) values = values_inter - - # preload with id # pour faire une seule requête - ids = filter( - lambda x: x, - map( - lambda x: x.get(id_field_name), - values - ) - ) - preload_res_with_ids = ( - Model.query - .filter(getattr(Model, id_field_name).in_(ids)) - .all() - ) + ids = filter(lambda x: x, map(lambda x: x.get(id_field_name), values)) + preload_res_with_ids = Model.query.filter( + getattr(Model, id_field_name).in_(ids) + ).all() # resul v_obj = [] @@ -409,39 +421,39 @@ def populatefn(self, dict_in, recursif=False): list( filter( lambda x: getattr(x, id_field_name) == id_value, - preload_res_with_ids + preload_res_with_ids, ) )[0] if id_value and len(preload_res_with_ids) # sinon on cree une nouvelle instance else Model() ) - if hasattr(res, 'from_dict'): + if hasattr(res, "from_dict"): res.from_dict(data, recursif) v_obj.append(res) # attribution de la relation # si uselist est à false -> on prend le premier de la liste - setattr( - self, - rel, - v_obj if uselist else v_obj[0] - ) + setattr(self, rel, v_obj if uselist else v_obj[0]) return self - if hasattr(cls, 'as_dict'): + if hasattr(cls, "as_dict"): # the Model has a as_dict(self, data) method, which expects serialized data as argument if len(signature(cls.as_dict).parameters) == 2: userfn = cls.as_dict + def chainedserializefn(self, *args, **kwargs): return userfn(self, serializefn(self, *args, **kwargs)) + cls.as_dict = chainedserializefn # the Model has its own as_dict method - elif 'as_dict' in vars(cls): + elif "as_dict" in vars(cls): # the serialize decorator is applied a second time with new default arguments - if hasattr(cls.as_dict, '__original_decorator') and (default_fields or default_exclude): + if hasattr(cls.as_dict, "__original_decorator") and ( + default_fields or default_exclude + ): cls.as_dict = serializefn # which will call super().as_dict() it-self else: diff --git a/src/utils_flask_sqla/tests/test_schema.py b/src/utils_flask_sqla/tests/test_schema.py index 890d37a..d31ce92 100644 --- a/src/utils_flask_sqla/tests/test_schema.py +++ b/src/utils_flask_sqla/tests/test_schema.py @@ -184,12 +184,11 @@ def test_null_relationship(self): }, ) - def test_polymorphic_model(self): class PolyModel(db.Model): __mapper_args__ = { - 'polymorphic_identity': 'IdentityBase', - 'polymorphic_on': 'kind', + "polymorphic_identity": "IdentityBase", + "polymorphic_on": "kind", } pk = db.Column(db.Integer, primary_key=True) kind = db.Column(db.String) @@ -201,7 +200,7 @@ class Meta: class PolyModelA(PolyModel): __mapper_args__ = { - 'polymorphic_identity': 'A', + "polymorphic_identity": "A", } pk = db.Column(db.Integer, db.ForeignKey(PolyModel.pk), primary_key=True) a = db.Column(db.String) @@ -212,7 +211,7 @@ class Meta: class PolyModelB(PolyModel): __mapper_args__ = { - 'polymorphic_identity': 'B', + "polymorphic_identity": "B", } pk = db.Column(db.Integer, db.ForeignKey(PolyModel.pk), primary_key=True) b = db.Column(db.String) @@ -221,25 +220,25 @@ class PolyModelBSchema(SmartRelationshipsMixin, SQLAlchemyAutoSchema): class Meta: model = PolyModelB - a = PolyModelA(pk=1, base='BA', a='A') + a = PolyModelA(pk=1, base="BA", a="A") TestCase().assertDictEqual( PolyModelASchema().dump(a), { - 'pk': 1, - 'kind': 'A', - 'base': 'BA', - 'a': 'A', + "pk": 1, + "kind": "A", + "base": "BA", + "a": "A", }, ) - b = PolyModelB(pk=2, base='BB', b='B') + b = PolyModelB(pk=2, base="BB", b="B") TestCase().assertDictEqual( PolyModelBSchema().dump(b), { - 'pk': 2, - 'kind': 'B', - 'base': 'BB', - 'b': 'B', + "pk": 2, + "kind": "B", + "base": "BB", + "b": "B", }, ) @@ -251,24 +250,25 @@ class HybridModel(db.Model): @hybrid_property def concat(self): - return '{0} {1}'.format(self.part1, self.part2) + return "{0} {1}".format(self.part1, self.part2) @concat.expression def concat(cls): - return db.func.concat(cls.part1, ' ', cls.part2) + return db.func.concat(cls.part1, " ", cls.part2) class HybridModelSchema(SmartRelationshipsMixin, SQLAlchemyAutoSchema): class Meta: model = HybridModel - concat = ma.fields.String(dump_only=True) - h = HybridModel(pk=1, part1='a', part2='b') + concat = ma.fields.String(dump_only=True) + + h = HybridModel(pk=1, part1="a", part2="b") TestCase().assertDictEqual( HybridModelSchema().dump(h), { - 'pk': 1, - 'part1': 'a', - 'part2': 'b', - 'concat': 'a b', + "pk": 1, + "part1": "a", + "part2": "b", + "concat": "a b", }, ) diff --git a/src/utils_flask_sqla/tests/test_serializers.py b/src/utils_flask_sqla/tests/test_serializers.py index ec0a4ca..9ac0ed9 100644 --- a/src/utils_flask_sqla/tests/test_serializers.py +++ b/src/utils_flask_sqla/tests/test_serializers.py @@ -27,7 +27,7 @@ class Parent(db.Model): class Child(db.Model): pk = db.Column(db.Integer, primary_key=True) parent_pk = db.Column(db.Integer, db.ForeignKey(Parent.pk)) - parent = relationship('Parent', backref='childs') + parent = relationship("Parent", backref="childs") @serializable(stringify=False) @@ -39,14 +39,14 @@ class A(db.Model): class B(db.Model): pk = db.Column(db.Integer, primary_key=True) a_pk = db.Column(db.Integer, db.ForeignKey(A.pk)) - a = relationship('A', backref='b_set') + a = relationship("A", backref="b_set") @serializable(stringify=False) class C(db.Model): pk = db.Column(db.Integer, primary_key=True) b_pk = db.Column(db.Integer, db.ForeignKey(B.pk)) - b = relationship('B', backref='c_set') + b = relationship("B", backref="c_set") class TestSerializers: @@ -69,46 +69,48 @@ class TestModel(db.Model): now = datetime.datetime.now() uuid = uuid4() - geom = wkt.loads('POINT(6 10)') - kwargs = dict(pk=1, - string='string', - unicode='unicode', - date=now.date(), - time=now.time(), - datetime=now, - boolean=True, - uuid=uuid, - hstore={'a': ['b', 'c']}, - json={'e': [1, 2]}, - jsonb=[3, {'f': 'g'}], - array=[1, 2], - geom=geom) + geom = wkt.loads("POINT(6 10)") + kwargs = dict( + pk=1, + string="string", + unicode="unicode", + date=now.date(), + time=now.time(), + datetime=now, + boolean=True, + uuid=uuid, + hstore={"a": ["b", "c"]}, + json={"e": [1, 2]}, + jsonb=[3, {"f": "g"}], + array=[1, 2], + geom=geom, + ) o = TestModel(**kwargs) d = o.as_dict(stringify=False) TestCase().assertDictEqual(kwargs, d) - d = o.as_dict(exclude=['geom']) + d = o.as_dict(exclude=["geom"]) json.dumps(d) # check dict is JSON-serializable assert d == { - 'pk': 1, - 'string': 'string', - 'unicode': 'unicode', - 'date': str(now.date()), - 'time': str(now.time()), - 'datetime': str(now), - 'boolean': True, - 'uuid': str(uuid), - 'hstore': { - 'a': ['b', 'c'], - }, - 'json': { - 'e': [1, 2], - }, - 'jsonb': [ + "pk": 1, + "string": "string", + "unicode": "unicode", + "date": str(now.date()), + "time": str(now.time()), + "datetime": str(now), + "boolean": True, + "uuid": str(uuid), + "hstore": { + "a": ["b", "c"], + }, + "json": { + "e": [1, 2], + }, + "jsonb": [ 3, - {'f': 'g'}, + {"f": "g"}, ], - 'array': [1, 2], + "array": [1, 2], } def test_many_to_one(self): @@ -116,70 +118,96 @@ def test_many_to_one(self): child = Child(pk=1, parent_pk=parent.pk, parent=parent) d = parent.as_dict() - TestCase().assertDictEqual({'pk': 1}, d) + TestCase().assertDictEqual({"pk": 1}, d) d = child.as_dict() - TestCase().assertDictEqual({'pk': 1, 'parent_pk': 1}, d) + TestCase().assertDictEqual({"pk": 1, "parent_pk": 1}, d) with pytest.deprecated_call(): - d = parent.as_dict(recursif=True) # no recursion limit - TestCase().assertDictEqual({ - 'pk': 1, - 'childs': [{ - 'pk': 1, - 'parent_pk': 1, - }], - }, d) + d = parent.as_dict(recursif=True) # no recursion limit + TestCase().assertDictEqual( + { + "pk": 1, + "childs": [ + { + "pk": 1, + "parent_pk": 1, + } + ], + }, + d, + ) with pytest.deprecated_call(): - d = child.as_dict(recursif=True) # no recursion limit - TestCase().assertDictEqual({ - 'pk': 1, - 'parent_pk': 1, - 'parent': { - 'pk': 1, + d = child.as_dict(recursif=True) # no recursion limit + TestCase().assertDictEqual( + { + "pk": 1, + "parent_pk": 1, + "parent": { + "pk": 1, + }, }, - }, d) + d, + ) with pytest.deprecated_call(): - d = parent.as_dict(recursif=True, depth=1) - TestCase().assertDictEqual({ - 'pk': 1, - 'childs': [{ - 'pk': 1, - 'parent_pk': 1, - }], - }, d) + d = parent.as_dict(recursif=True, depth=1) + TestCase().assertDictEqual( + { + "pk": 1, + "childs": [ + { + "pk": 1, + "parent_pk": 1, + } + ], + }, + d, + ) with pytest.deprecated_call(): - d = child.as_dict(recursif=True, depth=1) - TestCase().assertDictEqual({ - 'pk': 1, - 'parent_pk': 1, - 'parent': { - 'pk': 1, + d = child.as_dict(recursif=True, depth=1) + TestCase().assertDictEqual( + { + "pk": 1, + "parent_pk": 1, + "parent": { + "pk": 1, + }, }, - }, d) + d, + ) with pytest.deprecated_call(): d = parent.as_dict(recursif=True, depth=2) # depth=2 must not go further than depth=1 because of loop-avoidance - TestCase().assertDictEqual({ - 'pk': 1, - 'childs': [{ - 'pk': 1, - 'parent_pk': 1, - }], - }, d) - - d = parent.as_dict(fields=['childs']) - TestCase().assertDictEqual({ - 'pk': 1, - 'childs': [{ - 'pk': 1, - 'parent_pk': 1, - }] - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "childs": [ + { + "pk": 1, + "parent_pk": 1, + } + ], + }, + d, + ) + + d = parent.as_dict(fields=["childs"]) + TestCase().assertDictEqual( + { + "pk": 1, + "childs": [ + { + "pk": 1, + "parent_pk": 1, + } + ], + }, + d, + ) def test_depth(self): a = A(pk=1) @@ -188,47 +216,69 @@ def test_depth(self): with pytest.deprecated_call(): d = a.as_dict(recursif=True, depth=0) - TestCase().assertDictEqual({ - 'pk': 1, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + }, + d, + ) with pytest.deprecated_call(): d = a.as_dict(recursif=True, depth=1) - TestCase().assertDictEqual({ - 'pk': 1, - 'b_set': [{ - 'pk': 10, - 'a_pk': 1, - }], - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "b_set": [ + { + "pk": 10, + "a_pk": 1, + } + ], + }, + d, + ) with pytest.deprecated_call(): d = a.as_dict(recursif=True, depth=2) - TestCase().assertDictEqual({ - 'pk': 1, - 'b_set': [{ - 'pk': 10, - 'a_pk': 1, - 'c_set': [{ - 'pk': 100, - 'b_pk': 10, - }], - }], - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "b_set": [ + { + "pk": 10, + "a_pk": 1, + "c_set": [ + { + "pk": 100, + "b_pk": 10, + } + ], + } + ], + }, + d, + ) with pytest.deprecated_call(): d = a.as_dict(recursif=True, depth=3) - TestCase().assertDictEqual({ - 'pk': 1, - 'b_set': [{ - 'pk': 10, - 'a_pk': 1, - 'c_set': [{ - 'pk': 100, - 'b_pk': 10, - }], - }], - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "b_set": [ + { + "pk": 10, + "a_pk": 1, + "c_set": [ + { + "pk": 100, + "b_pk": 10, + } + ], + } + ], + }, + d, + ) def test_relationships(self): a = A(pk=1) @@ -236,73 +286,95 @@ def test_relationships(self): c = C(pk=100, b_pk=b.pk, b=b) # as 'pk' is specified, 'b_pk' is not taken - d = c.as_dict(fields=['pk', 'b']) - TestCase().assertDictEqual({ - 'pk': 100, - 'b': { - 'pk': 10, - 'a_pk': 1, - }, - }, d) - - d = c.as_dict(fields=['pk', 'b.pk']) - TestCase().assertDictEqual({ - 'pk': 100, - 'b': { - 'pk': 10, - }, - }, d) - - d = c.as_dict(fields=['pk', 'b.a']) - TestCase().assertDictEqual({ - 'pk': 100, - 'b': { - 'pk': 10, - 'a_pk': 1, - 'a': { - 'pk': 1, + d = c.as_dict(fields=["pk", "b"]) + TestCase().assertDictEqual( + { + "pk": 100, + "b": { + "pk": 10, + "a_pk": 1, }, }, - }, d) - - d = c.as_dict(fields=['pk', 'b.pk', 'b.a']) - TestCase().assertDictEqual({ - 'pk': 100, - 'b': { - 'pk': 10, - 'a': { - 'pk': 1, + d, + ) + + d = c.as_dict(fields=["pk", "b.pk"]) + TestCase().assertDictEqual( + { + "pk": 100, + "b": { + "pk": 10, }, }, - }, d) - - d = c.as_dict(fields=['b.c_set.b', 'b.c_set.pk']) - TestCase().assertDictEqual({ - 'pk': 100, - 'b_pk': 10, - 'b': { - 'pk': 10, - 'a_pk': 1, - 'c_set': [{ - 'pk': 100, - 'b': { - 'pk': 10, - 'a_pk': 1, + d, + ) + + d = c.as_dict(fields=["pk", "b.a"]) + TestCase().assertDictEqual( + { + "pk": 100, + "b": { + "pk": 10, + "a_pk": 1, + "a": { + "pk": 1, }, - }], + }, }, - }, d) - - d = c.as_dict(fields=['b', 'b.c_set'], exclude=['b_pk', 'b.a_pk', 'b.c_set.pk']) - TestCase().assertDictEqual({ - 'pk': 100, - 'b': { - 'pk': 10, - 'c_set': [{ - 'b_pk': 10, - }], + d, + ) + + d = c.as_dict(fields=["pk", "b.pk", "b.a"]) + TestCase().assertDictEqual( + { + "pk": 100, + "b": { + "pk": 10, + "a": { + "pk": 1, + }, + }, + }, + d, + ) + + d = c.as_dict(fields=["b.c_set.b", "b.c_set.pk"]) + TestCase().assertDictEqual( + { + "pk": 100, + "b_pk": 10, + "b": { + "pk": 10, + "a_pk": 1, + "c_set": [ + { + "pk": 100, + "b": { + "pk": 10, + "a_pk": 1, + }, + } + ], + }, }, - }, d) + d, + ) + + d = c.as_dict(fields=["b", "b.c_set"], exclude=["b_pk", "b.a_pk", "b.c_set.pk"]) + TestCase().assertDictEqual( + { + "pk": 100, + "b": { + "pk": 10, + "c_set": [ + { + "b_pk": 10, + } + ], + }, + }, + d, + ) def test_unexisting_field(self): @serializable(stringify=False) @@ -310,111 +382,149 @@ class Model(db.Model): pk = db.Column(db.Integer, primary_key=True) obj = Model(pk=1) - d = obj.as_dict(fields=['pk']) - TestCase().assertDictEqual({'pk': 1}, d) + d = obj.as_dict(fields=["pk"]) + TestCase().assertDictEqual({"pk": 1}, d) with pytest.raises(Exception) as excinfo: - obj.as_dict(fields=['unexisting']) - assert('does not exist on' in str(excinfo.value)) + obj.as_dict(fields=["unexisting"]) + assert "does not exist on" in str(excinfo.value) with pytest.raises(Exception) as excinfo: - obj.as_dict(fields=['pk.unexisting']) - assert('does not exist on' in str(excinfo.value)) + obj.as_dict(fields=["pk.unexisting"]) + assert "does not exist on" in str(excinfo.value) def test_backward_compatibility(self): parent = Parent(pk=1) child = Child(pk=1, parent_pk=parent.pk, parent=parent) with pytest.deprecated_call(): - d = child.as_dict(columns=['parent_pk']) - TestCase().assertDictEqual({'parent_pk': 1}, d) + d = child.as_dict(columns=["parent_pk"]) + TestCase().assertDictEqual({"parent_pk": 1}, d) with pytest.deprecated_call(): - d = parent.as_dict(relationships=['childs']) - TestCase().assertDictEqual({ - 'pk': 1, - 'childs': [{ - 'pk': 1, - 'parent_pk': 1, - }], - }, d) + d = parent.as_dict(relationships=["childs"]) + TestCase().assertDictEqual( + { + "pk": 1, + "childs": [ + { + "pk": 1, + "parent_pk": 1, + } + ], + }, + d, + ) with pytest.deprecated_call(): - d = child.as_dict(columns=['pk'], relationships=['parent']) - TestCase().assertDictEqual({ - 'pk': 1, - 'parent': { - 'pk': 1, + d = child.as_dict(columns=["pk"], relationships=["parent"]) + TestCase().assertDictEqual( + { + "pk": 1, + "parent": { + "pk": 1, + }, }, - }, d) + d, + ) def test_serializable_parameters(self): - @serializable(fields=['v_set'], stringify=False) + @serializable(fields=["v_set"], stringify=False) class U(db.Model): pk = db.Column(db.Integer, primary_key=True) - @serializable(exclude=['u_pk'], stringify=False) + @serializable(exclude=["u_pk"], stringify=False) class V(db.Model): pk = db.Column(db.Integer, primary_key=True) u_pk = db.Column(db.Integer, db.ForeignKey(U.pk)) - u = relationship('U', backref='v_set') + u = relationship("U", backref="v_set") u = U(pk=1) v = V(pk=1, u_pk=u.pk, u=u) d = u.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'v_set': [{ - 'pk': 1, - }], - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "v_set": [ + { + "pk": 1, + } + ], + }, + d, + ) d = v.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + }, + d, + ) d = u.as_dict(fields=None, exclude=None) - TestCase().assertDictEqual({ - 'pk': 1, - 'v_set': [{ - 'pk': 1, - }], - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "v_set": [ + { + "pk": 1, + } + ], + }, + d, + ) d = v.as_dict(fields=None, exclude=None) - TestCase().assertDictEqual({ - 'pk': 1, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + }, + d, + ) d = u.as_dict(fields=[]) - TestCase().assertDictEqual({ - 'pk': 1, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + }, + d, + ) d = v.as_dict(exclude=[]) - TestCase().assertDictEqual({ - 'pk': 1, - 'u_pk': 1, - }, d) - - d = v.as_dict(fields=['u.pk']) - TestCase().assertDictEqual({ - 'pk': 1, - 'u': { - 'pk': 1, - }, - }, d) - - d = v.as_dict(fields=['u']) # use default U serialization parameters - TestCase().assertDictEqual({ - 'pk': 1, - 'u': { - 'pk': 1, - 'v_set': [{ - 'pk': 1, - }], - }, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "u_pk": 1, + }, + d, + ) + + d = v.as_dict(fields=["u.pk"]) + TestCase().assertDictEqual( + { + "pk": 1, + "u": { + "pk": 1, + }, + }, + d, + ) + + d = v.as_dict(fields=["u"]) # use default U serialization parameters + TestCase().assertDictEqual( + { + "pk": 1, + "u": { + "pk": 1, + "v_set": [ + { + "pk": 1, + } + ], + }, + }, + d, + ) def test_as_dict_override(self): @serializable(stringify=False) @@ -422,7 +532,7 @@ class O(db.Model): pk = db.Column(db.Integer, primary_key=True) def as_dict(self, data): - data['o'] = True + data["o"] = True return data @serializable(stringify=False) @@ -433,76 +543,94 @@ class P(O): class Q(O): def as_dict(self, data): data = super().as_dict() - data['q'] = True + data["q"] = True return data @serializable(stringify=False) class R(P): def as_dict(self, data): - data['r'] = True + data["r"] = True return data @serializable(stringify=False) class S(P): def as_dict(self): data = super().as_dict() - data['s'] = True + data["s"] = True return data @serializable(stringify=False) class T(O): def as_dict(self, *args, **kwargs): data = super().as_dict(*args, **kwargs) - data['t'] = True + data["t"] = True return data - @serializable(exclude=['pk'], stringify=False) + @serializable(exclude=["pk"], stringify=False) @serializable class W(O): pass o = O(pk=1) d = o.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'o': True, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "o": True, + }, + d, + ) p = P(pk=2) d = p.as_dict() - TestCase().assertDictEqual({ - 'pk': 2, - }, d) + TestCase().assertDictEqual( + { + "pk": 2, + }, + d, + ) q = Q(pk=3) d = q.as_dict() - TestCase().assertDictEqual({ - 'pk': 3, - 'o': True, - 'q': True, - }, d) + TestCase().assertDictEqual( + { + "pk": 3, + "o": True, + "q": True, + }, + d, + ) r = R(pk=4) d = r.as_dict() - TestCase().assertDictEqual({ - 'pk': 4, - 'r': True, - }, d) + TestCase().assertDictEqual( + { + "pk": 4, + "r": True, + }, + d, + ) s = S(pk=5) d = s.as_dict() - TestCase().assertDictEqual({ - 'pk': 5, - 's': True, - }, d) + TestCase().assertDictEqual( + { + "pk": 5, + "s": True, + }, + d, + ) t = T(pk=6) - d = t.as_dict(fields=['pk']) - TestCase().assertDictEqual({ - 'pk': 6, - 'o': True, - 't': True, - }, d) + d = t.as_dict(fields=["pk"]) + TestCase().assertDictEqual( + { + "pk": 6, + "o": True, + "t": True, + }, + d, + ) w = W(pk=7) d = w.as_dict() @@ -512,38 +640,47 @@ def test_renamed_field(self): @serializable(stringify=False) class TestModel2(db.Model): pk = db.Column(db.Integer, primary_key=True) - field = db.Column('column', db.String) + field = db.Column("column", db.String) - o = TestModel2(pk=1, field='test') + o = TestModel2(pk=1, field="test") d = o.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'field': 'test', - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "field": "test", + }, + d, + ) def test_null_relationship(self): parent = Parent(pk=1) child = Child(pk=1, parent_pk=None, parent=None) - d = parent.as_dict(fields=['childs']) - TestCase().assertDictEqual({ - 'pk': 1, - 'childs': [], - }, d) - - d = child.as_dict(fields=['parent']) - TestCase().assertDictEqual({ - 'pk': 1, - 'parent_pk': None, - 'parent': None, - }, d) + d = parent.as_dict(fields=["childs"]) + TestCase().assertDictEqual( + { + "pk": 1, + "childs": [], + }, + d, + ) + + d = child.as_dict(fields=["parent"]) + TestCase().assertDictEqual( + { + "pk": 1, + "parent_pk": None, + "parent": None, + }, + d, + ) def test_polymorphic_model(self): @serializable(stringify=False) class PolyModel(db.Model): __mapper_args__ = { - 'polymorphic_identity': 'IdentityBase', - 'polymorphic_on': 'kind', + "polymorphic_identity": "IdentityBase", + "polymorphic_on": "kind", } pk = db.Column(db.Integer, primary_key=True) kind = db.Column(db.String) @@ -552,7 +689,7 @@ class PolyModel(db.Model): @serializable(stringify=False) class PolyModelA(PolyModel): __mapper_args__ = { - 'polymorphic_identity': 'A', + "polymorphic_identity": "A", } pk = db.Column(db.Integer, db.ForeignKey(PolyModel.pk), primary_key=True) a = db.Column(db.String) @@ -560,34 +697,40 @@ class PolyModelA(PolyModel): @serializable(stringify=False) class PolyModelB(PolyModel): __mapper_args__ = { - 'polymorphic_identity': 'B', + "polymorphic_identity": "B", } pk = db.Column(db.Integer, db.ForeignKey(PolyModel.pk), primary_key=True) b = db.Column(db.String) - a = PolyModelA(pk=1, base='BA', a='A') + a = PolyModelA(pk=1, base="BA", a="A") d = a.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'kind': 'A', - 'base': 'BA', - 'a': 'A', - }, d) - - b = PolyModelB(pk=2, base='BB', b='B') + TestCase().assertDictEqual( + { + "pk": 1, + "kind": "A", + "base": "BA", + "a": "A", + }, + d, + ) + + b = PolyModelB(pk=2, base="BB", b="B") d = b.as_dict() - TestCase().assertDictEqual({ - 'pk': 2, - 'kind': 'B', - 'base': 'BB', - 'b': 'B', - }, d) + TestCase().assertDictEqual( + { + "pk": 2, + "kind": "B", + "base": "BB", + "b": "B", + }, + d, + ) d = { - 'pk': 2, - 'kind': 'B', - 'base': 'BB', - 'b': 'B', + "pk": 2, + "kind": "B", + "base": "BB", + "b": "B", } TestCase().assertDictEqual(d, PolyModelB().from_dict(d).as_dict()) @@ -604,12 +747,15 @@ def sum(self): h = PropertyModel(pk=1, a=2, b=3) d = h.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'a': 2, - 'b': 3, - 'sum': 5, - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "a": 2, + "b": 3, + "sum": 5, + }, + d, + ) def test_hybrid_property(self): @serializable(stringify=False) @@ -620,96 +766,126 @@ class HybridModel(db.Model): @hybrid_property def concat(self): - return '{0} {1}'.format(self.part1, self.part2) + return "{0} {1}".format(self.part1, self.part2) @concat.expression def concat(cls): - return db.func.concat(cls.part1, ' ', cls.part2) + return db.func.concat(cls.part1, " ", cls.part2) - h = HybridModel(pk=1, part1='a', part2='b') + h = HybridModel(pk=1, part1="a", part2="b") d = h.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'part1': 'a', - 'part2': 'b', - 'concat': 'a b', - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "part1": "a", + "part2": "b", + "concat": "a b", + }, + d, + ) def test_additional_fields(self): - @serializable(fields=['v_set'], exclude=['field2']) + @serializable(fields=["v_set"], exclude=["field2"]) class U2(db.Model): pk = db.Column(db.Integer, primary_key=True) field1 = db.Column(db.String) field2 = db.Column(db.String) - @serializable(fields=['pk']) + @serializable(fields=["pk"]) class V2(db.Model): pk = db.Column(db.Integer, primary_key=True) u_pk = db.Column(db.Integer, db.ForeignKey(U2.pk)) - u = relationship('U2', backref='v_set') + u = relationship("U2", backref="v_set") - u = U2(pk=1, field1='test', field2='test') + u = U2(pk=1, field1="test", field2="test") v = V2(pk=1, u_pk=u.pk, u=u) d = u.as_dict() - TestCase().assertDictEqual({ - 'pk': 1, - 'field1': 'test', - 'v_set': [{'pk': 1}], - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "field1": "test", + "v_set": [{"pk": 1}], + }, + d, + ) d = u.as_dict(exclude=[]) - TestCase().assertDictEqual({ - 'pk': 1, - 'field1': 'test', - 'field2': 'test', - 'v_set': [{'pk': 1}], - }, d) - - d = u.as_dict(fields=['field1']) - TestCase().assertDictEqual({ - 'field1': 'test', - }, d) + TestCase().assertDictEqual( + { + "pk": 1, + "field1": "test", + "field2": "test", + "v_set": [{"pk": 1}], + }, + d, + ) + + d = u.as_dict(fields=["field1"]) + TestCase().assertDictEqual( + { + "field1": "test", + }, + d, + ) # Verify that field2 is removed from default_exclude - d = u.as_dict(fields=['field1', 'field2']) - TestCase().assertDictEqual({ - 'field1': 'test', - 'field2': 'test', - }, d) + d = u.as_dict(fields=["field1", "field2"]) + TestCase().assertDictEqual( + { + "field1": "test", + "field2": "test", + }, + d, + ) # Verify that field2 is removed from default_exclude - d = u.as_dict(fields=['field1', '+field2']) - TestCase().assertDictEqual({ - 'field1': 'test', - 'field2': 'test', - }, d) + d = u.as_dict(fields=["field1", "+field2"]) + TestCase().assertDictEqual( + { + "field1": "test", + "field2": "test", + }, + d, + ) # Verify that field2 is removed from default_exclude - d = u.as_dict(fields=['field2']) - TestCase().assertDictEqual({ - 'field2': 'test', - }, d) + d = u.as_dict(fields=["field2"]) + TestCase().assertDictEqual( + { + "field2": "test", + }, + d, + ) # Verify that field2 is removed from default_exclude, will keeping default_fields - d = u.as_dict(fields=['+field2', 'v_set']) - TestCase().assertDictEqual({ - 'pk': 1, - 'field1': 'test', - 'field2': 'test', - 'v_set': [{'pk': 1}], - }, d) - - d = u.as_dict(fields=['v_set.u_pk']) - TestCase().assertDictEqual({ - 'pk': 1, - 'field1': 'test', - 'v_set': [{'u_pk': 1}], - }, d) - - d = u.as_dict(fields=['v_set.+u_pk']) - TestCase().assertDictEqual({ - 'pk': 1, - 'field1': 'test', - 'v_set': [{'pk': 1, 'u_pk': 1}], - }, d) + d = u.as_dict(fields=["+field2", "v_set"]) + TestCase().assertDictEqual( + { + "pk": 1, + "field1": "test", + "field2": "test", + "v_set": [{"pk": 1}], + }, + d, + ) + + d = u.as_dict(fields=["v_set.u_pk"]) + TestCase().assertDictEqual( + { + "pk": 1, + "field1": "test", + "v_set": [{"u_pk": 1}], + }, + d, + ) + + d = u.as_dict(fields=["v_set.+u_pk"]) + TestCase().assertDictEqual( + { + "pk": 1, + "field1": "test", + "v_set": [{"pk": 1, "u_pk": 1}], + }, + d, + )