Skip to content

Commit

Permalink
Use apsw to force sqlite version and switch to jsonb for improved per…
Browse files Browse the repository at this point in the history
…formance

Signed-off-by: Prabhu Subramanian <[email protected]>
  • Loading branch information
prabhu committed Mar 18, 2024
1 parent 79d27df commit 4f09d8c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 36 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ dependencies = [
"packageurl-python",
"cvss",
"pydantic[email]",
"rich"
"rich",
"apsw>=3.45.2.0"
]
requires-python = ">=3.10"
readme = "README.md"
Expand Down
8 changes: 5 additions & 3 deletions vdb/lib/cve.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ def store(self, data: list[Vulnerability]):
self.store5(cve5_list)

def store5(self, data: list[CVE]):
self.db_conn.execute("BEGIN")
self.index_conn.execute("BEGIN")
for d in data:
cve_id = d.cveMetadata.cveId
if d.containers.cna and d.containers.cna.affected:
Expand All @@ -373,7 +375,7 @@ def store5(self, data: list[CVE]):
exclude_unset=True,
exclude_none=True)
self.db_conn.execute(
"INSERT INTO cve_data values(?, ?, ?, ?, json(?), ?);", (
"INSERT INTO cve_data values(?, ?, ?, ?, jsonb(?), ?);", (
cve_id,
affected.vendor,
affected.product,
Expand All @@ -389,5 +391,5 @@ def store5(self, data: list[CVE]):
to_purl_vers(affected.vendor, affected.versions)
)
)
self.db_conn.commit()
self.index_conn.commit()
self.db_conn.execute("COMMIT")
self.index_conn.execute("COMMIT")
48 changes: 19 additions & 29 deletions vdb/lib/db6.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,47 @@
import os
import sqlite3
import sys

import apsw

from vdb.lib import config

db_conn: sqlite3.Connection = None
index_conn: sqlite3.Connection = None
db_conn: apsw.Connection = None
index_conn: apsw.Connection = None
tables_created = False
db_file_sep = "///" if sys.platform == "win32" else "//"
DB_FILE_SEP = "///" if sys.platform == "win32" else "//"


def ensure_schemas(db_conn: sqlite3.Connection, index_conn: sqlite3.Connection):
def ensure_schemas(db_conn_obj: apsw.Connection, index_conn_obj: apsw.Connection):
"""Create the sqlite tables and indexes in case they don't exist"""
db_conn.execute(
"CREATE TABLE if not exists cve_data(cve_id TEXT NOT NULL, type TEXT NOT NULL, namespace TEXT, name TEXT NOT NULL, source_data JSON NOT NULL, override_data JSON);")
db_conn.executescript("""PRAGMA synchronous = OFF;
PRAGMA journal_mode = MEMORY;
""")
db_conn.commit()
index_conn.execute(
db_conn_obj.execute(
f"CREATE TABLE if not exists cve_data(cve_id TEXT NOT NULL, type TEXT NOT NULL, namespace TEXT, name TEXT NOT NULL, source_data BLOB NOT NULL, override_data BLOB);")
db_conn_obj.pragma("synchronous", "OFF")
db_conn_obj.pragma("journal_mode", "MEMORY")
index_conn_obj.execute(
"CREATE TABLE if not exists cve_index(cve_id TEXT NOT NULL, type TEXT NOT NULL, namespace TEXT, name TEXT NOT NULL, vers TEXT NOT NULL);")
index_conn.executescript("""PRAGMA synchronous = OFF;
PRAGMA journal_mode = MEMORY;
""")
index_conn.commit()
index_conn_obj.pragma("synchronous", "OFF")
index_conn_obj.pragma("journal_mode", "MEMORY")


def get(db_file: str = config.VDB_BIN_FILE, index_file: str = config.VDB_BIN_INDEX, read_only=False) -> (
sqlite3.Connection, sqlite3.Connection):
apsw.Connection, apsw.Connection):
global db_conn, index_conn, tables_created
if not db_file.startswith("file:"):
db_file = f"file:{db_file_sep}{os.path.abspath(db_file)}"
db_file = f"file:{DB_FILE_SEP}{os.path.abspath(db_file)}"
if not index_file.startswith("file:"):
index_file = f"file:{db_file_sep}{os.path.abspath(index_file)}"
if read_only:
db_file = f"{db_file}?mode=ro"
index_file = f"{index_file}?mode=ro"
index_file = f"file:{DB_FILE_SEP}{os.path.abspath(index_file)}"
flags = apsw.SQLITE_OPEN_URI | apsw.SQLITE_OPEN_NOFOLLOW | (apsw.SQLITE_OPEN_READONLY if read_only else apsw.SQLITE_OPEN_CREATE | apsw.SQLITE_OPEN_READWRITE)
if not db_conn:
db_conn = sqlite3.connect(db_file, uri=True)
db_conn = apsw.Connection(db_file, flags=flags)
if not index_conn:
index_conn = sqlite3.connect(index_file, uri=True)
index_conn = apsw.Connection(index_file, flags=flags)
if not tables_created:
ensure_schemas(db_conn, index_conn)
tables_created = True
return db_conn, index_conn


def stats():
global db_conn, index_conn
cve_data_count = 0
res = db_conn.execute("SELECT count(*) FROM cve_data").fetchone()
if res:
Expand All @@ -62,10 +56,8 @@ def stats():
def clear_all():
if db_conn:
db_conn.execute("DELETE FROM cve_data;")
db_conn.commit()
if index_conn:
index_conn.execute("DELETE FROM cve_index;")
index_conn.commit()


def optimize_and_close_all():
Expand All @@ -76,7 +68,6 @@ def optimize_and_close_all():
db_conn.execute(
"CREATE INDEX if not exists idx1 on cve_data(cve_id, type);")
db_conn.execute("VACUUM;")
db_conn.commit()
db_conn.close()
if index_conn:
index_conn.execute(
Expand All @@ -88,5 +79,4 @@ def optimize_and_close_all():
index_conn.execute(
"CREATE INDEX if not exists cidx4 on cve_index(namespace, name);")
index_conn.execute("VACUUM;")
index_conn.commit()
index_conn.close()
6 changes: 3 additions & 3 deletions vdb/lib/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_cve_data(db_conn, index_hits: list[dict, Any], search_str: str) -> list[
data_list = []
for ahit in index_hits:
results = exec_query(db_conn,
"SELECT cve_id, type, namespace, name, source_data, override_data FROM cve_data WHERE cve_id = ? AND type = ? ORDER BY cve_id DESC;",
"SELECT cve_id, type, namespace, name, json_object('source', source_data), json_object('override', override_data) FROM cve_data WHERE cve_id = ? AND type = ? ORDER BY cve_id DESC;",
(ahit["cve_id"], ahit["type"]))
for res in results:
data_list.append({
Expand All @@ -48,8 +48,8 @@ def get_cve_data(db_conn, index_hits: list[dict, Any], search_str: str) -> list[
"name": res[3],
"matching_vers": ahit["vers"],
"matched_by": search_str,
"source_data": CVE.model_validate(orjson.loads(res[4]), strict=False) if res[4] else None,
"override_data": orjson.loads(res[5]) if res[5] else None
"source_data": CVE.model_validate(orjson.loads(res[4])["source"], strict=False) if res[4] else None,
"override_data": orjson.loads(res[5])["override"] if res[5] else None
})
return data_list

Expand Down

0 comments on commit 4f09d8c

Please sign in to comment.