diff --git a/test/test_db.py b/test/test_db.py
deleted file mode 100644
index 0cb7f32..0000000
--- a/test/test_db.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import json
-import os
-import tempfile
-
-import pytest
-
-from vdb.lib import db as db
-from vdb.lib.gha import GitHubSource
-from vdb.lib.nvd import NvdSource
-from vdb.lib.utils import parse_cpe
-
-
-@pytest.fixture
-def test_db():
-    with tempfile.NamedTemporaryFile(delete=False) as fp:
-        with tempfile.NamedTemporaryFile(delete=False) as indexfp:
-            return db.get(db_file=fp.name, index_file=indexfp.name)
-
-
-@pytest.fixture
-def test_vuln_data():
-    test_cve_data = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "data", "cve_data.json"
-    )
-    with open(test_cve_data, "r") as fp:
-        json_data = json.loads(fp.read())
-        nvdlatest = NvdSource()
-        return nvdlatest.convert(json_data)
-
-
-@pytest.fixture
-def test_gha_data():
-    test_cve_data = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "data", "gha_data.json"
-    )
-    with open(test_cve_data, "r") as fp:
-        json_data = json.loads(fp.read())
-        ghalatest = GitHubSource()
-        return ghalatest.convert(json_data)[0]
-
-
-def test_create(test_db, test_vuln_data):
-    docs = db.store(test_db, test_vuln_data)
-    assert len(docs) > len(test_vuln_data)
-
-
-@pytest.mark.skip(reason="Slow test")
-def test_search_slow(test_db, test_vuln_data):
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_vuln_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    for d in all_data:
-        res = db.pkg_search(
-            table,
-            d["details"]["package"],
-            d["details"]["mai"],
-        )
-        assert len(res)
-        assert res[0].to_dict()["package_issue"]
-
-
-def test_search_fast(test_db, test_vuln_data):
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_vuln_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    search_list = [
-        {
-            "name": d["details"]["package"],
-            "version": d["details"]["mai"],
-        }
-        for d in all_data
-    ]
-    res = db.bulk_index_search(search_list)
-    assert len(res)
-
-
-def test_gha_create(test_db, test_gha_data):
-    docs = db.store(test_db, test_gha_data)
-    assert len(docs) > len(test_gha_data)
-
-
-def test_gha_search_slow(test_db, test_gha_data):
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_gha_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    for d in all_data:
-        version = d["details"]["mai"]
-        if version and version != "*":
-            res = db.pkg_search(
-                table,
-                d["details"]["package"],
-                version,
-            )
-            assert len(res)
-            assert res[0].to_dict()["package_issue"]
-
-
-def test_gha_vendor_search(test_db, test_gha_data):
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_gha_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    for d in all_data:
-        vendor, _, _, cve_type = parse_cpe(d["details"]["cpe_uri"])
-        version = d["details"]["mai"]
-        if version and version != "*":
-            res = db.vendor_pkg_search(
-                table,
-                vendor,
-                d["details"]["package"],
-                version,
-            )
-            assert len(res)
-            assert res[0].to_dict()["package_issue"]
-
-
-def test_gha_search_bulk(test_db, test_gha_data):
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_gha_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    tmp_list = [
-        {
-            "name": d["details"]["package"],
-            "version": d["details"]["mai"],
-        }
-        for d in all_data
-        if d["details"]["mai"] != "*"
-    ]
-    res = db.bulk_index_search(tmp_list)
-    assert len(res)
-
-
-def test_index_search(test_db, test_vuln_data):
-    # This slow test ensures that every data in the main database is indexed
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_vuln_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    tmp_list = []
-    for d in all_data[:40]:
-        version = d["details"]["mai"]
-        if version and version != "*":
-            tmp_list.append({"name": d["details"]["package"], "version": version})
-    res = db.bulk_index_search(tmp_list)
-    assert len(res)
-    for r in res:
-        name_ver = r.split("|")
-        fullres = db.index_search(name_ver[1], name_ver[2])
-        assert fullres
-
-
-def test_vendor_index_search(test_db, test_vuln_data):
-    # This slow test ensures that every data in the main database is indexed
-    table = test_db
-    docs = db.list_all(table)
-    assert len(docs) == 0
-    docs = db.store(test_db, test_vuln_data)
-    assert len(docs) > 0
-    all_data = db.list_all(table)
-    assert all_data
-    tmp_list = []
-    for d in all_data[:40]:
-        vendor, _, _, cve_type = parse_cpe(d["details"]["cpe_uri"])
-        tmp_list.append(
-            {
-                "vendor": vendor,
-                "name": d["details"]["package"],
-                "version": d["details"]["mai"],
-            }
-        )
-    res = db.bulk_index_search(tmp_list)
-    assert len(res)
-    for r in res:
-        name_ver = r.split("|")
-        fullres = db.index_search(name_ver[2], name_ver[3])
-        assert fullres
diff --git a/vdb/cli.py b/vdb/cli.py
index 05e0fb5..f4f4e5d 100644
--- a/vdb/cli.py
+++ b/vdb/cli.py
@@ -170,7 +170,7 @@ def main():
         for s in sources:
             LOG.info("Refreshing %s", s.__class__.__name__)
             s.refresh()
-        db_lib.close_all()
+        db_lib.optimize_and_close_all()
     elif args.sync:
         for s in (GitHubSource(),):
             LOG.info("Syncing %s", s.__class__.__name__)
diff --git a/vdb/lib/db6.py b/vdb/lib/db6.py
index 8c70469..cc3dba6 100644
--- a/vdb/lib/db6.py
+++ b/vdb/lib/db6.py
@@ -16,16 +16,8 @@ def ensure_schemas(db_conn: sqlite3.Connection, index_conn: sqlite3.Connection):
     """Create the sqlite tables and indexes in case they don't exist"""
     db_conn.execute(
         "CREATE TABLE if not exists cve_data(cve_id TEXT NOT NULL, type TEXT NOT NULL, namespace TEXT, name TEXT NOT NULL, source_data JSON NOT NULL, override_data JSON);")
-    db_conn.execute(
-        "CREATE INDEX if not exists idx1 on cve_data(cve_id, type);")
     index_conn.execute(
         "CREATE TABLE if not exists cve_index(cve_id TEXT NOT NULL, type TEXT NOT NULL, namespace TEXT, name TEXT NOT NULL, versions JSON NOT NULL);")
-    index_conn.execute(
-        "CREATE INDEX if not exists cidx1 on cve_index(cve_id);")
-    index_conn.execute(
-        "CREATE INDEX if not exists cidx2 on cve_index(type, namespace, name);")
-    index_conn.execute(
-        "CREATE INDEX if not exists cidx3 on cve_index(namespace, name);")
 
 
 def get(db_file: str = config.VDB_BIN_FILE, index_file: str = config.VDB_BIN_INDEX, read_only=False) -> (
@@ -57,12 +49,23 @@ def clear_all():
         index_conn.commit()
 
 
-def close_all():
+def optimize_and_close_all():
+    """
+    Safely close the connections by creating indexes and vacuuming if needed.
+    """
     if db_conn:
+        db_conn.execute(
+            "CREATE INDEX if not exists idx1 on cve_data(cve_id, type);")
         db_conn.execute("VACUUM;")
         db_conn.commit()
         db_conn.close()
     if index_conn:
+        index_conn.execute(
+            "CREATE INDEX if not exists cidx1 on cve_index(cve_id);")
+        index_conn.execute(
+            "CREATE INDEX if not exists cidx2 on cve_index(type, namespace, name);")
+        index_conn.execute(
+            "CREATE INDEX if not exists cidx3 on cve_index(namespace, name);")
         index_conn.execute("VACUUM;")
         index_conn.commit()
         index_conn.close()