From 8cbb96ca336fedd5ef2bca44c0373652b3bd95c8 Mon Sep 17 00:00:00 2001 From: Adam <56907039+adamjsturge@users.noreply.github.com> Date: Tue, 18 Jun 2024 19:58:12 -0700 Subject: [PATCH] Single DB Instance --- api.go | 28 ++++++++++++++-------------- database.go | 16 ++++++---------- dbhelper.go | 17 +++++++++-------- main.go | 4 ++-- 4 files changed, 31 insertions(+), 34 deletions(-) diff --git a/api.go b/api.go index 9551f18..0271cf7 100644 --- a/api.go +++ b/api.go @@ -34,8 +34,8 @@ func settingsHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not authenticated", http.StatusUnauthorized) } if r.Method == "GET" { - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query("SELECT key, value FROM settings WHERE key IN ($1, $2, $3, $4)", CORRELATION_API_SECRET_SETTINGS_KEY, CHAINLOAD_URI_SETTINGS_KEY, PAGES_TO_COLLECT_SETTINGS_KEY, SEND_ALERTS_SETTINGS_KEY) if err != nil { @@ -136,8 +136,8 @@ func payloadFiresHandler(w http.ResponseWriter, r *http.Request) { limit := parameter_to_int(limit_string, 10) offset := page * limit - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query("SELECT id, url, ip_address, referer, user_agent, cookies, title, dom, text, origin, screenshot_id, was_iframe, browser_timestamp, injection_requests_id FROM payload_fire_results ORDER BY created_at DESC LIMIT $1 OFFSET $2", limit, offset) if err != nil { @@ -168,8 +168,8 @@ func payloadFiresHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "No ids to delete", http.StatusBadRequest) return } - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query("SELECT screenshot_id FROM payload_fire_results WHERE id IN ($1)", ids_to_delete) if err != nil { @@ -214,8 +214,8 @@ func collectedPagesHandler(w http.ResponseWriter, r *http.Request) { limit := parameter_to_int(limit_string, 10) offset := page * limit - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query("SELECT id, uri FROM collected_pages ORDER BY created_at DESC LIMIT $1 OFFSET $2", limit, offset) if err != nil { @@ -311,8 +311,8 @@ func userPayloadsHandler(w http.ResponseWriter, r *http.Request) { // limit := parameter_to_int(limit_string, 10) // offset := page * limit - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query("SELECT id, payload, title, description, author, author_link FROM user_xss_payloads ORDER BY created_at ASC") if err != nil { @@ -339,8 +339,8 @@ func userPayloadsHandler(w http.ResponseWriter, r *http.Request) { } } else if r.Method == "POST" { - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() stmt, _ := db.Prepare(`INSERT INTO user_xss_payloads (payload, title, description, author, author_link) VALUES ($1, $2, $3, $4, $5)`) _, err := stmt.Exec(r.FormValue("payload"), r.FormValue("title"), r.FormValue("description"), r.FormValue("author"), r.FormValue("author_link")) @@ -374,8 +374,8 @@ func userPayloadImporterHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid method", http.StatusMethodNotAllowed) return } - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() var user_payloads []UserXSSPayloads err := json.NewDecoder(r.Body).Decode(&user_payloads) diff --git a/database.go b/database.go index 3b65bbd..841171a 100644 --- a/database.go +++ b/database.go @@ -41,9 +41,11 @@ type InjectionRequests struct { Injection_key string } +var db *sql.DB + func create_sqlite_tables() { - db := establish_sqlite_database_connection() - defer db.Close() + // db := establish_sqlite_database_connection() + // defer db.Close() sqlStmt := ` CREATE TABLE IF NOT EXISTS settings ( @@ -102,8 +104,8 @@ func create_sqlite_tables() { } func create_postgres_tables() { - db := establish_postgres_database_connection() - defer db.Close() + // db := establish_postgres_database_connection() + // defer db.Close() sqlStmt := ` CREATE TABLE IF NOT EXISTS settings ( @@ -190,9 +192,6 @@ func initialize_users() { } func setup_admin_user(password string) bool { - db := establish_database_connection() - defer db.Close() - hashed_password, err := hash_string(password) if err != nil { log.Fatal(err) @@ -218,9 +217,6 @@ func initialize_correlation_api() { } func initialize_setting_helper(key string, value string) bool { - db := establish_database_connection() - defer db.Close() - has_setting, setting_err := db_single_item_query("SELECT 1 FROM settings WHERE key = $1", key).toBool() if setting_err != nil { log.Fatal(setting_err) diff --git a/dbhelper.go b/dbhelper.go index fa248a2..3db4fda 100644 --- a/dbhelper.go +++ b/dbhelper.go @@ -26,8 +26,8 @@ type ResultsObject map[string]Result //lint:ignore U1000 Ignore unused function temporarily for debugging func db_select(query string, args ...any) (ResultsObjectArray, error) { - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query(query, args...) if err != nil { @@ -153,8 +153,8 @@ func toBool(value interface{}) (bool, error) { } func db_single_item_query(query string, args ...any) SingleResult { - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() var result interface{} err := db.QueryRow(query, args...).Scan(&result) @@ -169,8 +169,8 @@ func db_single_item_query(query string, args ...any) SingleResult { } func db_prepare_execute(query string, args ...any) (sql.Result, error) { - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() stmt, _ := db.Prepare(query) result, err := stmt.Exec(args...) @@ -182,8 +182,8 @@ func db_prepare_execute(query string, args ...any) (sql.Result, error) { } func db_execute(query string, args ...any) (sql.Result, error) { - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() result, err := db.Exec(query, args...) if err != nil { @@ -194,6 +194,7 @@ func db_execute(query string, args ...any) (sql.Result, error) { } func initialize_database() { + db = establish_database_connection() if is_postgres { initialize_postgres_database() } else { diff --git a/main.go b/main.go index 291e5dc..bc82124 100644 --- a/main.go +++ b/main.go @@ -255,8 +255,8 @@ func jscallbackHandler(w http.ResponseWriter, r *http.Request) { if injection_key != "" { query := "SELECT id, request FROM injection_requests WHERE injection_key = $1" - db := establish_database_connection() - defer db.Close() + // db := establish_database_connection() + // defer db.Close() rows, err := db.Query(query, injection_key) if err != nil {