From a149f31c3a0e2f26621284078d5e6358130fabbd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anton=C3=ADn=20Dufka?=
Date: Fri, 4 Aug 2023 13:22:10 +0200
Subject: [PATCH] Add db export functionality

---
 README.md                         |  26 ++-
 dissect/utils/database_handler.py | 270 ++++++++++++++++++++----------
 2 files changed, 206 insertions(+), 90 deletions(-)

diff --git a/README.md b/README.md
index 1145f4b..ec470cb 100644
--- a/README.md
+++ b/README.md
@@ -118,14 +118,30 @@ dissect-feature_clusters features.csv outliers.csv
 
 ## Database
 
-Command `dissect-database` provides a simple interface for uploading DiSSECT data from JSON files to a database for further analysis. To use this command you have to provide database URL which should be a string in format `"mongodb://USERNAME:PASSWORD@HOST/"` (e.g., `"mongodb://root:password@mongo:27017/`).
+Command `dissect-database` provides a simple interface for importing and exporting database data. To use this command, you have to provide a database URL, which should be a string in the format `"mongodb://USERNAME:PASSWORD@HOST/"` (e.g., `"mongodb://root:password@mongo:27017/"`), and select whether you want to `import` or `export` data.
 
-To upload curves from a JSON file, use:
+Curves can be imported from a JSON file with the following command:
 ```shell
-dissect-database curves [DATABASE_URL] <curve_files>
+dissect-database [--database_url DATABASE_URL] import -i <curves_file>
 ```
 
-To upload trait results from a JSON file, use:
+Trait results can be imported using the same command, but the file name has to start with the `trait_` prefix, for example, `trait_cofactor.json`:
 ```shell
-dissect-database traits [DATABASE_URL] <trait_name> <results_file>
+dissect-database [--database_url DATABASE_URL] import -i <trait_results_file>
+```
+
+To export curves, use:
+```shell
+dissect-database [--database_url DATABASE_URL] export --no-traits -o <output_file>
+```
+
+To export the results of a selected trait, use:
+```shell
+dissect-database [--database_url DATABASE_URL] export --no-curves --trait <trait_name> -o <output_file>
+```
+
+All records in the database can be exported and imported using the following commands.
+```shell
+dissect-database [--database_url DATABASE_URL] export [-o <output_file>]
+dissect-database [--database_url DATABASE_URL] import [-i <input_file>]
 ```
\ No newline at end of file
diff --git a/dissect/utils/database_handler.py b/dissect/utils/database_handler.py
index 3c99b6e..ea13a69 100644
--- a/dissect/utils/database_handler.py
+++ b/dissect/utils/database_handler.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
+import bz2
 import json
-import sys
+import os
 from typing import Optional, Tuple, Iterable, Dict, Any
 
 from pymongo import MongoClient
@@ -38,7 +39,8 @@ def _format_curve(curve):
     c["params"] = curve["params"]
     try:
         if (curve["generator"]["x"]["raw"] or curve["generator"]["x"]["poly"]) and (
-                curve["generator"]["y"]["raw"] or curve["generator"]["y"]["poly"]):
+            curve["generator"]["y"]["raw"] or curve["generator"]["y"]["poly"]
+        ):
             c["generator"] = curve["generator"]
     except:
         pass
@@ -62,7 +64,11 @@ def _format_curve(curve):
     sim = {}
     if "seed" in curve and curve["seed"]:
         sim["seed"] = hex(int(curve["seed"], base=16))
-    elif "characteristics" in curve and "seed" in curve["characteristics"] and curve["characteristics"]["seed"]:
+    elif (
+        "characteristics" in curve
+        and "seed" in curve["characteristics"]
+        and curve["characteristics"]["seed"]
+    ):
         sim["seed"] = hex(int(curve["characteristics"]["seed"], base=16))
 
     if sim:
@@ -79,19 +85,17 @@ def _format_curve(curve):
     return c
 
 
-def upload_curves(db: Database, path: str = None) -> Tuple[int, int]:
+def import_curves(db: Database, curves: Any) -> Tuple[int, int]:
     try:
-        if path:
-            with open(path, "r") as f:
-                curves = json.load(f)
-        else:
-            curves = json.load(sys.stdin)
-
-        if not isinstance(curves, list):  # inconsistency between simulated and standard format
+        if not isinstance(
+            curves, list
+        ):  # inconsistency between simulated and standard format
             curves = curves["curves"]
     except Exception:  # invalid format
         return 0, 0
 
+    create_curves_index(db)
+
     success = 0
     for curve in curves:
         try:
@@ -103,19 +107,14 @@ def upload_curves(db: Database, path: str = None) -> Tuple[int, int]:
     return success, len(curves)
 
 
-def upload_results(db: Database, trait_name: str, path: str = None) -> Tuple[int, int]:
-    try:
-        if path:
-            with open(path, "r") as f:
-                results = json.load(f)
-        else:
-            results = json.load(sys.stdin)
-    except Exception:  # invalid format
-        return 0, 0
+def import_trait_results(
+    db: Database, trait_name: str, trait_results: Any
+) -> Tuple[int, int]:
+    create_trait_index(db, trait_name)
 
     success = 0
     total = 0
-    for result in results:
+    for result in trait_results:
         total += 1
 
         record = {}
@@ -145,7 +144,9 @@ def upload_results(db: Database, trait_name: str, path: str = None) -> Tuple[int
 def get_curves(db: Database, query: Any = None) -> Iterable[Any]:
     aggregate_pipeline = []
-    aggregate_pipeline.append({"$match": format_curve_query(query) if query else dict()})
+    aggregate_pipeline.append(
+        {"$match": format_curve_query(query) if query else dict()}
+    )
     aggregate_pipeline.append({"$unset": "_id"})
 
     curves = list(db["curves"].aggregate(aggregate_pipeline))
@@ -159,10 +160,11 @@ def get_curves_count(db: Database, query: Any = None) -> int:
 
 def get_curve_categories(db: Database) -> Iterable[str]:
     return db["curves"].distinct("category")
 
+
 def format_curve_query(query: Dict[str, Any]) -> Dict[str, Any]:
     result = {}
 
-    def helper(key, cast, db_key = None):
+    def helper(key, cast, db_key=None):
         if key not in query:
             return
 
@@ -174,7 +176,7 @@ def helper(key, cast, db_key = None):
             if len(query[key]) == 1:
                 result[db_key] = cast(query[key][0])
             else:
-                result[db_key] = { "$in": list(map(cast, query[key])) }
+                result[db_key] = {"$in": list(map(cast, query[key]))}
         elif query[key] != "all":
             result[db_key] = cast(query[key])
 
@@ -191,6 +193,7 @@ def helper(key, cast, db_key = None):
 
 def _cast_sage_types(result: Any) -> Any:
     from sage.all import Integer
+
     if isinstance(result, Integer):
         return int(result)
 
@@ -206,6 +209,7 @@ def _cast_sage_types(result: Any) -> Any:
 
 def _encode_ints(result: Any) -> Any:
     from sage.all import Integer
+
     if isinstance(result, Integer) or isinstance(result, int):
         return hex(result)
     if isinstance(result, dict):
@@ -219,11 +223,11 @@ def _encode_ints(result: Any) -> Any:
 
 def store_trait_result(
-        db: Database,
-        curve: Any,
-        trait: str,
-        params: Dict[str, Any],
-        result: Dict[str, Any],
+    db: Database,
+    curve: Any,
+    trait: str,
+    params: Dict[str, Any],
+    result: Dict[str, Any],
 ) -> bool:
     trait_result = {}
     trait_result["curve"] = {}
@@ -243,22 +247,19 @@ def store_trait_result(
     return False
 
 
-def is_solved(
-        db: Database, curve: Any, trait: str, params: Dict[str, Any]
-) -> bool:
-    trait_result = { "curve.name": curve.name() }
+def is_solved(db: Database, curve: Any, trait: str, params: Dict[str, Any]) -> bool:
+    trait_result = {"curve.name": curve.name()}
     trait_result["params"] = _cast_sage_types(params)
     return db[f"trait_{trait}"].find_one(trait_result) is not None
 
 
 def get_trait_results(
-        db: Database,
-        trait: str,
-        query: Dict[str, Any] = None,
-        limit: int = None
+    db: Database, trait: str, query: Dict[str, Any] = None, limit: int = None
 ):
     aggregate_pipeline = []
-    aggregate_pipeline.append({"$match": format_trait_query(trait, query) if query else dict()})
+    aggregate_pipeline.append(
+        {"$match": format_trait_query(trait, query) if query else dict()}
+    )
     aggregate_pipeline.append({"$unset": "_id"})
     if limit:
         aggregate_pipeline.append({"$limit": limit})
@@ -266,14 +267,18 @@ def get_trait_results(
     aggregated = list(db[f"trait_{trait}"].aggregate(aggregate_pipeline))
     return map(_decode_ints, map(_flatten_trait_result, aggregated))
 
+
 def get_trait_results_count(db: Database, trait: str, query: Dict[str, Any] = None):
-    return db[f"trait_{trait}"].count_documents(format_trait_query(trait, query) if query else dict())
+    return db[f"trait_{trait}"].count_documents(
+        format_trait_query(trait, query) if query else dict()
+    )
+
 
 def format_trait_query(trait_name: str, query: Dict[str, Any]) -> Dict[str, Any]:
     result = {}
     query = query.copy()
 
-    def helper(key, cast, db_key = None):
+    def helper(key, cast, db_key=None):
         if key not in query:
             return
 
@@ -285,7 +290,7 @@ def helper(key, cast, db_key = None):
             if len(query[key]) == 1:
                 result[db_key] = cast(query[key][0])
             else:
-                result[db_key] = { "$in": list(map(cast, query[key])) }
+                result[db_key] = {"$in": list(map(cast, query[key]))}
         elif query[key] != "all":
             result[db_key] = cast(query[key])
 
@@ -303,12 +308,15 @@ def helper(key, cast, db_key = None):
         helper(key, TRAITS[trait_name].INPUT[key][0], f"params.{key}")
 
     for key in TRAITS[trait_name].OUTPUT:
-        helper(key, lambda x: _encode_ints(TRAITS[trait_name].OUTPUT[key][0](x)), f"result.{key}")
+        helper(
+            key,
+            lambda x: _encode_ints(TRAITS[trait_name].OUTPUT[key][0](x)),
+            f"result.{key}",
+        )
 
     return result
 
-
 # TODO move to data_processing?
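+# Round-trip illustration of the hex encoding performed by _encode_ints /
+# _decode_ints above (an editor's sketch with arbitrary values, not part of
+# the original change):
+#
+#     _encode_ints({"order": 23})      ->  {"order": "0x17"}
+#     _decode_ints({"order": "0x17"})  ->  {"order": 23}
+#
+# format_trait_query relies on this convention: trait OUTPUT values are
+# hex-encoded with _encode_ints before being matched against stored records.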
 def _flatten_trait_result(record: Dict[str, Any]):
     output = dict()
 
@@ -323,7 +331,7 @@ def _flatten_trait_result(record: Dict[str, Any]):
 
 
 def _flatten_trait_result_rec(
-        record: Dict[str, Any], prefix: str, output: Dict[str, Any]
+    record: Dict[str, Any], prefix: str, output: Dict[str, Any]
 ):
     for key in record:
         if isinstance(record[key], dict):
@@ -333,7 +341,9 @@ def _flatten_trait_result_rec(
 
 def _decode_ints(source: Any) -> Any:
-    if isinstance(source, str) and (source[:2].lower() == "0x" or source[:3].lower() == "-0x"):
+    if isinstance(source, str) and (
+        source[:2].lower() == "0x" or source[:3].lower() == "-0x"
+    ):
         return int(source, base=16)
     if isinstance(source, dict):
         for key, value in source.items():
@@ -344,48 +354,138 @@ def _decode_ints(source: Any) -> Any:
     return source
 
 
+def import_file(db: Database, path: str):
+    name = os.path.basename(path)
+    if name.startswith("trait_"):
+        trait_name = name[len("trait_") :].split(os.extsep, 1)[0]
+        print(f"Importing trait {trait_name} from {name}")
+        if name.endswith(".json.bz2"):
+            with bz2.open(path, "rb") as f:
+                results = json.load(f)
+        elif name.endswith(".json"):
+            with open(path, "r") as f:
+                results = json.load(f)
+        else:
+            raise ValueError("Invalid file format")
+
+        succ, total = import_trait_results(db, trait_name, results)
+        print(f"- imported {succ} out of {total} trait results")
+    else:
+        print(f"Importing curves from {name}")
+
+        if name.endswith(".json"):
+            with open(path, "r") as f:
+                curves = json.load(f)
+        elif name.endswith(".json.bz2"):
+            with bz2.open(path, "rb") as f:
+                curves = json.load(f)
+        else:
+            raise ValueError("Invalid file format")
+
+        succ, total = import_curves(db, curves)
+        print(f"- imported {succ} out of {total} curves")
+
+
 def main():
-    import sys
+    import argparse
+    import tempfile
+    import tarfile
+    import shutil
 
-    if len(sys.argv) < 3 or not sys.argv[1] in ("curves", "traits"):
-        print(
-            f"USAGE: {sys.argv[0]} curves [database_url] [curve_files...]",
-            file=sys.stderr,
-        )
-        print(
-            f"   OR: {sys.argv[0]} traits [database_url] <trait_name> [results_file]",
-            file=sys.stderr,
-        )
-        sys.exit(1)
-
-    database_url = "mongodb://localhost:27017/"
-    args = sys.argv[2:]
-    for idx, arg in enumerate(args):
-        if "mongodb://" in arg:
-            database_url = arg
-            del args[idx]
-            break
-
-    print(f"Connecting to database {database_url}")
-    db = connect(database_url)
-
-    def upload_curves_from_files(curve_files_list):
-        for curves_file in curve_files_list:
-            print(f"Loading curves from file {curves_file}")
-            create_curves_index(db)
-            uploaded, total = upload_curves(db, curves_file)
-            print(f"Successfully uploaded {uploaded} out of {total}")
-
-    def upload_results_from_file(trait_name, results_file):
-        print(f"Loading trait {trait_name} results from file {results_file}")
-        create_trait_index(db, trait_name)
-        uploaded, total = upload_results(db, trait_name, results_file)
-        print(f"Successfully uploaded {uploaded} out of {total}")
-
-    if sys.argv[1] == "curves":
-        upload_curves_from_files(args if args else [None])
-    elif sys.argv[1] == "traits":
-        upload_results_from_file(args[0], args[1] if len(args) > 1 else None)
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--database_url", type=str, default="mongodb://localhost:27017/"
+    )
+    subparsers = parser.add_subparsers(dest="command")
+
+    parser_export = subparsers.add_parser("export")
+    parser_export.add_argument("-o", "--output", type=str, default="dissect.tar")
+    parser_export.add_argument("--no-curves", default=False, action="store_true")
+    parser_export.add_argument("--no-traits", default=False, action="store_true")
+    parser_export.add_argument("--trait", type=str, default=["all"], nargs="*")
+
+    parser_import = subparsers.add_parser("import")
+    parser_import.add_argument("-i", "--input", type=str, default="dissect.tar")
+
+    args = parser.parse_args()
+
+    db = connect(args.database_url)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+
+        def export_records(collection_name):
+            document_count = db[collection_name].estimated_document_count()
+
+            if document_count == 0:
+                print(f"Skipping {collection_name} (no records)")
+                return
+
+            with open(os.path.join(tmpdir, f"{collection_name}.json.bz2"), "wb") as f:
+                print(f"Exporting {collection_name} (~{document_count} records)")
+                compressor = bz2.BZ2Compressor()
+                f.write(compressor.compress(b"[\n"))
+                for idx, record in enumerate(db[collection_name].find()):
+                    if idx != 0:
+                        f.write(compressor.compress(b",\n"))
+                    del record["_id"]
+                    f.write(compressor.compress(json.dumps(record, indent=2).encode()))
+                f.write(compressor.compress(b"\n]\n"))
+                f.write(compressor.flush())
+
+        if args.command == "export":
+            if not args.no_curves:
+                export_records("curves")
+
+            if not args.no_traits:
+                trait_collections = list(
+                    filter(lambda x: x.startswith("trait_"), db.list_collection_names())
+                )
+                if "all" not in args.trait:
+                    trait_collections = list(
+                        filter(
+                            lambda x: x[len("trait_") :] in args.trait,
+                            trait_collections,
+                        )
+                    )
+                for trait_collection in trait_collections:
+                    export_records(trait_collection)
+
+            output_files = os.listdir(tmpdir)
+
+            if os.path.isdir(args.output):
+                for file in output_files:
+                    shutil.copyfile(
+                        os.path.join(tmpdir, file), os.path.join(args.output, file)
+                    )
+            elif len(output_files) == 1 and args.output.endswith(".json.bz2"):
+                shutil.copyfile(
+                    os.path.join(tmpdir, output_files[0]),
+                    args.output,
+                )
+            elif len(output_files) == 1 and args.output.endswith(".json"):
+                with open(args.output, "wb") as output_file, open(
+                    os.path.join(tmpdir, output_files[0]), "rb"
+                ) as input_file:
+                    decompressor = bz2.BZ2Decompressor()
+                    for data in iter(lambda: input_file.read(1024 * 1024), b""):
+                        output_file.write(decompressor.decompress(data))
+            else:
+                with tarfile.open(args.output, "w") as tar:
+                    for file in output_files:
+                        tar.add(os.path.join(tmpdir, file), arcname=file)
+
+        elif args.command == "import":
+            if args.input.endswith(".tar"):
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    with tarfile.open(args.input, "r") as tar:
+                        tar.extractall(tmpdir)
+
+                    for file in os.listdir(tmpdir):
+                        import_file(db, os.path.join(tmpdir, file))
+            elif args.input.endswith(".json") or args.input.endswith(".json.bz2"):
+                import_file(db, args.input)
+            else:
+                print("Unknown input format")
 
 
 if __name__ == "__main__":
     main()