From d59ef13386f2d85cac9b318f902eb11d7564e42f Mon Sep 17 00:00:00 2001 From: rizwana hardoyal Date: Fri, 9 Dec 2022 12:58:17 +0400 Subject: [PATCH] postgres integration with sql changes for lcp, lsd --- config/config.go | 4 ++ dbmodel/postgres _db_setup_frontend.sql | 50 +++++++++++++++++++++++++ dbmodel/postgres_db_setup_lcpserver.sql | 25 +++++++++++++ dbmodel/postgres_db_setup_lsdserver.sql | 26 +++++++++++++ dbutils/dbutils.go | 31 +++++++++++++++ dbutils/dbutils_test.go | 17 +++++++++ index/index.go | 34 +++++++++++++---- license/store.go | 33 ++++++++-------- license_statuses/license_statuses.go | 17 +++++---- transactions/transactions.go | 15 ++++---- 10 files changed, 214 insertions(+), 38 deletions(-) create mode 100644 dbmodel/postgres _db_setup_frontend.sql create mode 100644 dbmodel/postgres_db_setup_lcpserver.sql create mode 100644 dbmodel/postgres_db_setup_lsdserver.sql create mode 100644 dbutils/dbutils.go create mode 100644 dbutils/dbutils_test.go diff --git a/config/config.go b/config/config.go index 9adfa079..3dfa5c6e 100644 --- a/config/config.go +++ b/config/config.go @@ -135,6 +135,10 @@ func GetDatabase(uri string) (string, string) { } parts := strings.Split(uri, "://") + if parts[0] == "postgres" { + return parts[0], uri + } + return parts[0], parts[1] } diff --git a/dbmodel/postgres _db_setup_frontend.sql b/dbmodel/postgres _db_setup_frontend.sql new file mode 100644 index 00000000..c03149b3 --- /dev/null +++ b/dbmodel/postgres _db_setup_frontend.sql @@ -0,0 +1,50 @@ +CREATE SEQUENCE publication_seq; + +CREATE TABLE publication ( + id int PRIMARY KEY DEFAULT NEXTVAL ('publication_seq'), + uuid varchar(255) NOT NULL, /* == content id */ + title varchar(255) NOT NULL, + status varchar(255) NOT NULL +); + +CREATE INDEX uuid_index ON publication (uuid); + +CREATE SEQUENCE user_seq; + +CREATE TABLE "user" ( + id int PRIMARY KEY DEFAULT NEXTVAL ('user_seq'), + uuid varchar(255) NOT NULL, + name varchar(64) NOT NULL, + email varchar(64) NOT NULL, + password varchar(64) NOT NULL, + hint varchar(64) NOT NULL +); + +CREATE SEQUENCE purchase_seq; + +CREATE TABLE purchase ( + id int PRIMARY KEY DEFAULT NEXTVAL ('purchase_seq'), + uuid varchar(255) NOT NULL, + publication_id int NOT NULL, + user_id int NOT NULL, + license_uuid varchar(255) NULL, + type varchar(32) NOT NULL, + transaction_date timestamp(0), + start_date timestamp(0), + end_date timestamp(0), + status varchar(255) NOT NULL, + FOREIGN KEY (publication_id) REFERENCES publication (id), + FOREIGN KEY (user_id) REFERENCES "user" (id) +); + +CREATE INDEX idx_purchase ON purchase (license_uuid); + +CREATE SEQUENCE license_view_seq; + +CREATE TABLE license_view ( + id int PRIMARY KEY DEFAULT NEXTVAL ('license_view_seq'), + uuid varchar(255) NOT NULL, + device_count int NOT NULL, + status varchar(255) NOT NULL, + message varchar(255) NOT NULL +); \ No newline at end of file diff --git a/dbmodel/postgres_db_setup_lcpserver.sql b/dbmodel/postgres_db_setup_lcpserver.sql new file mode 100644 index 00000000..0097a35b --- /dev/null +++ b/dbmodel/postgres_db_setup_lcpserver.sql @@ -0,0 +1,25 @@ + +CREATE TABLE content ( + id varchar(255) PRIMARY KEY NOT NULL, + encryption_key bytea NOT NULL, + location text NOT NULL, + length bigint, + sha256 varchar(64), + type varchar(255) NOT NULL DEFAULT 'application/epub+zip' +); + +-- SQLINES LICENSE FOR EVALUATION USE ONLY +CREATE TABLE license ( + id varchar(255) PRIMARY KEY NOT NULL, + user_id varchar(255) NOT NULL, + provider varchar(255) NOT NULL, + issued timestamp(0) NOT NULL, + updated timestamp(0) DEFAULT NULL, + rights_print int DEFAULT NULL, + rights_copy int DEFAULT NULL, + rights_start timestamp(0) DEFAULT NULL, + rights_end timestamp(0) DEFAULT NULL, + content_fk varchar(255) NOT NULL, + lsd_status int default 0, + FOREIGN KEY(content_fk) REFERENCES content(id) +); \ No newline at end of file diff --git a/dbmodel/postgres_db_setup_lsdserver.sql b/dbmodel/postgres_db_setup_lsdserver.sql new file mode 100644 index 00000000..39fff02c --- /dev/null +++ b/dbmodel/postgres_db_setup_lsdserver.sql @@ -0,0 +1,26 @@ +CREATE TABLE license_status ( + id serial4 NOT NULL, + status smallint NOT NULL, + license_updated timestamp(3) NOT NULL, + status_updated timestamp(3) NOT NULL, + device_count smallint DEFAULT NULL, + potential_rights_end timestamp(3) DEFAULT NULL, + license_ref varchar(255) NOT NULL, + rights_end timestamp(3) DEFAULT NULL, + CONSTRAINT license_status_pkey PRIMARY KEY (id) +); + +CREATE INDEX license_ref_index ON license_status (license_ref); + +CREATE TABLE event ( + id serial4 NOT NULL, + device_name varchar(255) DEFAULT NULL, + timestamp timestamp(3) NOT NULL, + type int NOT NULL, + device_id varchar(255) DEFAULT NULL, + license_status_fk int NOT NULL, + CONSTRAINT event_pkey PRIMARY KEY (id), + FOREIGN KEY(license_status_fk) REFERENCES license_status(id) +); + +CREATE INDEX license_status_fk_index on event (license_status_fk); \ No newline at end of file diff --git a/dbutils/dbutils.go b/dbutils/dbutils.go new file mode 100644 index 00000000..5e7aa8bc --- /dev/null +++ b/dbutils/dbutils.go @@ -0,0 +1,31 @@ +package dbutils + +import ( + "bytes" + "fmt" + "strings" +) + +func getPostgresQuery(query string) string { + var buffer bytes.Buffer + idx := 1 + for _, char := range query { + if char == '?' { + buffer.WriteString(fmt.Sprintf("$%d", idx)) + idx += 1 + } else { + buffer.WriteRune(char) + } + } + return buffer.String() +} + +// GetParamQuery replaces parameter placeholders '?' in the SQL query to +// placeholders supported by the selected database driver. +func GetParamQuery(database, query string) string { + if strings.HasPrefix(database, "postgres") { + return getPostgresQuery(query) + } else { + return query + } +} diff --git a/dbutils/dbutils_test.go b/dbutils/dbutils_test.go new file mode 100644 index 00000000..50a5cf07 --- /dev/null +++ b/dbutils/dbutils_test.go @@ -0,0 +1,17 @@ +package dbutils + +import "testing" + +const demo_query = "SELECT * FROM test WHERE id = ? AND test = ? LIMIT 1" + +func TestGetParamQuery(t *testing.T) { + q := GetParamQuery("postgres", demo_query) + if q != "SELECT * FROM test WHERE id = $1 AND test = $2 LIMIT 1" { + t.Fatalf("Incorrect postgres query") + } + + q = GetParamQuery("sqlite3", demo_query) + if q != "SELECT * FROM test WHERE id = ? AND test = ? LIMIT 1" { + t.Fatalf("Incorrect sqlite3 query") + } +} diff --git a/index/index.go b/index/index.go index 47f11dab..108efdc0 100644 --- a/index/index.go +++ b/index/index.go @@ -10,6 +10,7 @@ import ( "log" "github.com/readium/readium-lcp-server/config" + "github.com/readium/readium-lcp-server/dbutils" ) // ErrNotFound signals content not found @@ -52,16 +53,35 @@ func (i dbIndex) Get(id string) (Content, error) { // Add inserts a record func (i dbIndex) Add(c Content) error { - _, err := i.db.Exec("INSERT INTO content (id,encryption_key,location,length,sha256,type) VALUES (?, ?, ?, ?, ?, ?)", - c.ID, c.EncryptionKey, c.Location, c.Length, c.Sha256, c.Type) - return err + driver, _ := config.GetDatabase(config.Config.LcpServer.Database) + + if driver == "postgres" { + _, err := i.db.Exec(dbutils.GetParamQuery(config.Config.LcpServer.Database, "INSERT INTO content (id,encryption_key,location,length,sha256,type) VALUES (?, ?::bytea, ?, ?, ?, ?)"), + c.ID, c.EncryptionKey, c.Location, c.Length, c.Sha256, c.Type) + return err + + } else { + _, err := i.db.Exec("INSERT INTO content (id,encryption_key,location,length,sha256,type) VALUES (?, ?, ?, ?, ?, ?)", + c.ID, c.EncryptionKey, c.Location, c.Length, c.Sha256, c.Type) + return err + } + } // Update updates a record func (i dbIndex) Update(c Content) error { - _, err := i.db.Exec("UPDATE content SET encryption_key=? , location=?, length=?, sha256=?, type=? WHERE id=?", - c.EncryptionKey, c.Location, c.Length, c.Sha256, c.Type, c.ID) - return err + driver, _ := config.GetDatabase(config.Config.LcpServer.Database) + + if driver == "postgres" { + _, err := i.db.Exec(dbutils.GetParamQuery(config.Config.LcpServer.Database, "UPDATE content SET encryption_key=?::bytea , location=?, length=?, sha256=?, type=? WHERE id=?"), + c.EncryptionKey, c.Location, c.Length, c.Sha256, c.Type, c.ID) + return err + } else { + _, err := i.db.Exec("UPDATE content SET encryption_key=? , location=?, length=?, sha256=?, type=? WHERE id=?", + c.EncryptionKey, c.Location, c.Length, c.Sha256, c.Type, c.ID) + return err + } + } // List lists rows @@ -97,7 +117,7 @@ func Open(db *sql.DB) (i Index, err error) { } var dbGetByID *sql.Stmt - dbGetByID, err = db.Prepare("SELECT id,encryption_key,location,length,sha256,type FROM content WHERE id = ?") + dbGetByID, err = db.Prepare(dbutils.GetParamQuery(config.Config.LcpServer.Database, "SELECT id,encryption_key,location,length,sha256,type FROM content WHERE id = ?")) if err != nil { return } diff --git a/license/store.go b/license/store.go index 00113c74..cade50cd 100644 --- a/license/store.go +++ b/license/store.go @@ -12,6 +12,7 @@ import ( "time" "github.com/readium/readium-lcp-server/config" + "github.com/readium/readium-lcp-server/dbutils" ) var ErrNotFound = errors.New("License not found") @@ -98,7 +99,7 @@ func (s *sqlStore) ListByContentID(contentID string, pageSize int, pageNum int) // UpdateRights func (s *sqlStore) UpdateRights(l License) error { - result, err := s.db.Exec("UPDATE license SET rights_print=?, rights_copy=?, rights_start=?, rights_end=?, updated=? WHERE id=?", + result, err := s.db.Exec(dbutils.GetParamQuery(config.Config.LcpServer.Database, "UPDATE license SET rights_print=?, rights_copy=?, rights_start=?, rights_end=?, updated=? WHERE id=?"), l.Rights.Print, l.Rights.Copy, l.Rights.Start, l.Rights.End, time.Now().UTC().Truncate(time.Second), l.ID) if err == nil { @@ -112,9 +113,9 @@ func (s *sqlStore) UpdateRights(l License) error { // Add creates a new record in the license table func (s *sqlStore) Add(l License) error { - _, err := s.db.Exec(`INSERT INTO license (id, user_id, provider, issued, updated, + _, err := s.db.Exec(dbutils.GetParamQuery(config.Config.LcpServer.Database, `INSERT INTO license (id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), l.ID, l.User.ID, l.Provider, l.Issued, nil, l.Rights.Print, l.Rights.Copy, l.Rights.Start, l.Rights.End, l.ContentID) @@ -124,9 +125,9 @@ func (s *sqlStore) Add(l License) error { // Update updates a record in the license table func (s *sqlStore) Update(l License) error { - _, err := s.db.Exec(`UPDATE license SET user_id=?,provider=?,updated=?, + _, err := s.db.Exec(dbutils.GetParamQuery(config.Config.LcpServer.Database, `UPDATE license SET user_id=?,provider=?,updated=?, rights_print=?, rights_copy=?, rights_start=?, rights_end=?, content_fk =? - WHERE id=?`, + WHERE id=?`), l.User.ID, l.Provider, time.Now().UTC().Truncate(time.Second), l.Rights.Print, l.Rights.Copy, l.Rights.Start, l.Rights.End, @@ -139,7 +140,7 @@ func (s *sqlStore) Update(l License) error { // UpdateLsdStatus func (s *sqlStore) UpdateLsdStatus(id string, status int32) error { - _, err := s.db.Exec(`UPDATE license SET lsd_status =? WHERE id=?`, + _, err := s.db.Exec(dbutils.GetParamQuery(config.Config.LcpServer.Database, `UPDATE license SET lsd_status =? WHERE id=?`), status, id) return err @@ -174,11 +175,11 @@ func Open(db *sql.DB) (store Store, err error) { var dbList *sql.Stmt if driver == "mssql" { - dbList, err = db.Prepare(`SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk - FROM license ORDER BY issued desc OFFSET ? ROWS FETCH NEXT ? ROWS ONLY`) + dbList, err = db.Prepare(dbutils.GetParamQuery(config.Config.LcpServer.Database, `SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk + FROM license ORDER BY issued desc OFFSET ? ROWS FETCH NEXT ? ROWS ONLY`)) } else { - dbList, err = db.Prepare(`SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk - FROM license ORDER BY issued desc LIMIT ? OFFSET ?`) + dbList, err = db.Prepare(dbutils.GetParamQuery(config.Config.LcpServer.Database, `SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk + FROM license ORDER BY issued desc LIMIT ? OFFSET ?`)) } if err != nil { log.Println("Error preparing dbList") @@ -187,13 +188,13 @@ func Open(db *sql.DB) (store Store, err error) { var dbListByContentID *sql.Stmt if driver == "mssql" { - dbListByContentID, err = db.Prepare(`SELECT id, user_id, provider, issued, updated, + dbListByContentID, err = db.Prepare(dbutils.GetParamQuery(config.Config.LcpServer.Database, `SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk - FROM license WHERE content_fk = ? ORDER BY issued desc OFFSET ? ROWS FETCH NEXT ? ROWS ONLY`) + FROM license WHERE content_fk = ? ORDER BY issued desc OFFSET ? ROWS FETCH NEXT ? ROWS ONLY`)) } else { - dbListByContentID, err = db.Prepare(`SELECT id, user_id, provider, issued, updated, + dbListByContentID, err = db.Prepare(dbutils.GetParamQuery(config.Config.LcpServer.Database, `SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk - FROM license WHERE content_fk = ? ORDER BY issued desc LIMIT ? OFFSET ?`) + FROM license WHERE content_fk = ? ORDER BY issued desc LIMIT ? OFFSET ?`)) } if err != nil { @@ -202,9 +203,9 @@ func Open(db *sql.DB) (store Store, err error) { } var dbGetByID *sql.Stmt - dbGetByID, err = db.Prepare(`SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, + dbGetByID, err = db.Prepare(dbutils.GetParamQuery(config.Config.LcpServer.Database, `SELECT id, user_id, provider, issued, updated, rights_print, rights_copy, rights_start, rights_end, content_fk - FROM license WHERE id = ?`) + FROM license WHERE id = ?`)) if err != nil { log.Println("Error preparing dbGetByID") return diff --git a/license_statuses/license_statuses.go b/license_statuses/license_statuses.go index 9caa5cf3..67e03116 100644 --- a/license_statuses/license_statuses.go +++ b/license_statuses/license_statuses.go @@ -11,6 +11,7 @@ import ( "time" "github.com/readium/readium-lcp-server/config" + "github.com/readium/readium-lcp-server/dbutils" "github.com/readium/readium-lcp-server/status" ) @@ -78,9 +79,9 @@ func (i dbLicenseStatuses) Add(ls LicenseStatus) error { if ls.PotentialRights != nil && ls.PotentialRights.End != nil && !(*ls.PotentialRights.End).IsZero() { end = ls.PotentialRights.End } - _, err = i.db.Exec(`INSERT INTO license_status + _, err = i.db.Exec(dbutils.GetParamQuery(config.Config.LsdServer.Database, `INSERT INTO license_status (status, license_updated, status_updated, device_count, potential_rights_end, license_ref, rights_end) - VALUES (?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?)`), statusDB, ls.Updated.License, ls.Updated.Status, ls.DeviceCount, end, ls.LicenseRef, ls.CurrentEndLicense) } @@ -173,8 +174,8 @@ func (i dbLicenseStatuses) Update(ls LicenseStatus) error { } var result sql.Result - result, err = i.db.Exec(`UPDATE license_status SET status=?, license_updated=?, status_updated=?, - device_count=?,potential_rights_end=?, rights_end=? WHERE id=?`, + result, err = i.db.Exec(dbutils.GetParamQuery(config.Config.LsdServer.Database, `UPDATE license_status SET status=?, license_updated=?, status_updated=?, + device_count=?,potential_rights_end=?, rights_end=? WHERE id=?`), statusInt, ls.Updated.License, ls.Updated.Status, ls.DeviceCount, potentialRightsEnd, ls.CurrentEndLicense, ls.ID) if err == nil { @@ -199,7 +200,7 @@ func Open(db *sql.DB) (l LicenseStatuses, err error) { } } - dbGet, err := db.Prepare("SELECT * FROM license_status WHERE id = ?") + dbGet, err := db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, "SELECT * FROM license_status WHERE id = ?")) if err != nil { return } @@ -209,15 +210,15 @@ func Open(db *sql.DB) (l LicenseStatuses, err error) { dbList, err = db.Prepare(`SELECT id, status, license_updated, status_updated, device_count, license_ref FROM license_status WHERE device_count >= ? ORDER BY id DESC OFFSET ? ROWS FETCH NEXT ? ROWS ONLY`) } else { - dbList, err = db.Prepare(`SELECT id, status, license_updated, status_updated, device_count, license_ref FROM license_status WHERE device_count >= ? - ORDER BY id DESC LIMIT ? OFFSET ?`) + dbList, err = db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, `SELECT id, status, license_updated, status_updated, device_count, license_ref FROM license_status WHERE device_count >= ? + ORDER BY id DESC LIMIT ? OFFSET ?`)) } if err != nil { return } - dbGetByLicenseID, err := db.Prepare("SELECT * FROM license_status where license_ref = ?") + dbGetByLicenseID, err := db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, "SELECT * FROM license_status where license_ref = ?")) if err != nil { return } diff --git a/transactions/transactions.go b/transactions/transactions.go index e2c796f3..3a73c8db 100644 --- a/transactions/transactions.go +++ b/transactions/transactions.go @@ -11,6 +11,7 @@ import ( "time" "github.com/readium/readium-lcp-server/config" + "github.com/readium/readium-lcp-server/dbutils" "github.com/readium/readium-lcp-server/status" ) @@ -70,7 +71,7 @@ func (i dbTransactions) Get(id int) (Event, error) { // The parameter eventType corresponds to the field 'type' in table 'event' func (i dbTransactions) Add(e Event, eventType int) error { - _, err := i.db.Exec("INSERT INTO event (device_name, timestamp, type, device_id, license_status_fk) VALUES (?, ?, ?, ?, ?)", + _, err := i.db.Exec(dbutils.GetParamQuery(config.Config.LsdServer.Database, "INSERT INTO event (device_name, timestamp, type, device_id, license_status_fk) VALUES (?, ?, ?, ?, ?)"), e.DeviceName, e.Timestamp, eventType, e.DeviceId, e.LicenseStatusFk) return err } @@ -151,12 +152,12 @@ func Open(db *sql.DB) (t Transactions, err error) { } // select an event by its id - dbGet, err := db.Prepare("SELECT * FROM event WHERE id = ?") + dbGet, err := db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, "SELECT * FROM event WHERE id = ?")) if err != nil { return } - dbGetByStatusID, err := db.Prepare("SELECT * FROM event WHERE license_status_fk = ?") + dbGetByStatusID, err := db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, "SELECT * FROM event WHERE license_status_fk = ?")) if err != nil { return } @@ -167,15 +168,15 @@ func Open(db *sql.DB) (t Transactions, err error) { dbCheckDeviceStatus, err = db.Prepare(`SELECT TOP 1 type FROM event WHERE license_status_fk = ? AND device_id = ? ORDER BY timestamp DESC`) } else { - dbCheckDeviceStatus, err = db.Prepare(`SELECT type FROM event WHERE license_status_fk = ? - AND device_id = ? ORDER BY timestamp DESC LIMIT 1`) + dbCheckDeviceStatus, err = db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, `SELECT type FROM event WHERE license_status_fk = ? + AND device_id = ? ORDER BY timestamp DESC LIMIT 1`)) } if err != nil { return } - dbListRegisteredDevices, err := db.Prepare(`SELECT device_id, - device_name, timestamp FROM event WHERE license_status_fk = ? AND type = 1`) + dbListRegisteredDevices, err := db.Prepare(dbutils.GetParamQuery(config.Config.LsdServer.Database, `SELECT device_id, + device_name, timestamp FROM event WHERE license_status_fk = ? AND type = 1`)) if err != nil { return }