Skip to content

Commit

Permalink
Merge pull request #614 from mfang90739/mariadb-backend
Browse files Browse the repository at this point in the history
Initial changes to support Mariadb backend.
  • Loading branch information
rpiazza authored Dec 11, 2024
2 parents 82aad1d + 41104f2 commit 19c2c53
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 1 deletion.
83 changes: 83 additions & 0 deletions stix2/datastore/relational_db/database_backends/mariadb_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
from typing import Any

from sqlalchemy import TIMESTAMP, LargeBinary, Text, VARCHAR
from sqlalchemy.schema import CreateSchema

from stix2.base import (
_DomainObject, _MetaObject, _Observable, _RelationshipObject,
)
from stix2.datastore.relational_db.utils import schema_for

from .database_backend_base import DatabaseBackend


class MariaDBBackend(DatabaseBackend):
default_database_connection_url = \
f"mariadbsql://{os.getenv('MARIADB_USER')}:" + \
f"{os.getenv('MARIADB_PASSWORD')}@" + \
f"{os.getenv('MARIADB_IP_ADDRESS')}:" + \
f"{os.getenv('MARIADB_PORT', '3306')}/rdb"

def __init__(self, database_connection_url=default_database_connection_url, force_recreate=False, **kwargs: Any):
super().__init__(database_connection_url, force_recreate=force_recreate, **kwargs)

# =========================================================================
# schema methods

def _create_schemas(self):
with self.database_connection.begin() as trans:
trans.execute(CreateSchema("common", if_not_exists=True))
trans.execute(CreateSchema("sdo", if_not_exists=True))
trans.execute(CreateSchema("sco", if_not_exists=True))
trans.execute(CreateSchema("sro", if_not_exists=True))

@staticmethod
def determine_schema_name(stix_object):
if isinstance(stix_object, _DomainObject):
return "sdo"
elif isinstance(stix_object, _Observable):
return "sco"
elif isinstance(stix_object, _RelationshipObject):
return "sro"
elif isinstance(stix_object, _MetaObject):
return "common"

@staticmethod
def schema_for(stix_class):
return schema_for(stix_class)

@staticmethod
def schema_for_core():
return "common"

# =========================================================================
# sql type methods (overrides)

@staticmethod
def determine_sql_type_for_key_as_id(): # noqa: F811
return VARCHAR(255)

@staticmethod
def determine_sql_type_for_binary_property(): # noqa: F811
return MariaDBBackend.determine_sql_type_for_string_property()

@staticmethod
def determine_sql_type_for_hex_property(): # noqa: F811
# return LargeBinary
return MariaDBBackend.determine_sql_type_for_string_property()

@staticmethod
def determine_sql_type_for_timestamp_property(): # noqa: F811
return TIMESTAMP(timezone=True)

# =========================================================================
# Other methods

@staticmethod
def array_allowed():
return False

@staticmethod
def create_regex_constraint_expression(column_name, pattern):
return f"{column_name} REGEXP {pattern}"
16 changes: 16 additions & 0 deletions stix2/datastore/relational_db/database_backends/sqlite_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ def determine_sql_type_for_hex_property(): # noqa: F811
# return LargeBinary
return SQLiteBackend.determine_sql_type_for_string_property()

@staticmethod
def determine_sql_type_for_reference_property(): # noqa: F811
return Text

@staticmethod
def determine_sql_type_for_string_property(): # noqa: F811
return Text

@staticmethod
def determine_sql_type_for_key_as_id(): # noqa: F811
return Text

@staticmethod
def determine_sql_type_for_timestamp_property(): # noqa: F811
return TIMESTAMP(timezone=True)
Expand All @@ -49,3 +61,7 @@ def determine_sql_type_for_timestamp_property(): # noqa: F811
@staticmethod
def array_allowed():
return False

@staticmethod
def create_regex_constraint_expression(column_name, pattern):
return f"{column_name} ~ {pattern}"
6 changes: 5 additions & 1 deletion stix2/datastore/relational_db/relational_db_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from database_backends.postgres_backend import PostgresBackend
from database_backends.sqlite_backend import SQLiteBackend
from database_backends.mariadb_backend import MariaDBBackend
import pytz
import os

import stix2
from stix2.datastore.relational_db.relational_db import RelationalDBStore
Expand Down Expand Up @@ -288,8 +290,10 @@ def test_dictionary():

def main():
store = RelationalDBStore(
PostgresBackend("postgresql://localhost/stix-data-sink", force_recreate=True),
MariaDBBackend("mariadb+pymysql://{os.getenv('MARIADB_USER')}:{os.getenv('MARIADB_PASSWORD')}@127.0.0.1:3306/rdb", force_recreate=True),
#PostgresBackend("postgresql://localhost/stix-data-sink", force_recreate=True),
#SQLiteBackend("sqlite:///stix-data-sink.db", force_recreate=True),

True,
None,
True,
Expand Down

0 comments on commit 19c2c53

Please sign in to comment.