Skip to content

Commit

Permalink
Allow update datasource
Browse files Browse the repository at this point in the history
  • Loading branch information
mawandm committed Apr 10, 2024
1 parent e8c664b commit 89c965d
Show file tree
Hide file tree
Showing 16 changed files with 679 additions and 107 deletions.
8 changes: 4 additions & 4 deletions nesis/api/core/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
GET = "GET"
POST = "POST"
DELETE = "DELETE"
PUT = "PUT"
GET: str = "GET"
POST: str = "POST"
DELETE: str = "DELETE"
PUT: str = "PUT"

from .api import app

Expand Down
95 changes: 46 additions & 49 deletions nesis/api/core/controllers/datasources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask import request, jsonify

from . import GET, POST, DELETE
import nesis.api.core.controllers as controllers
from .api import app, error_message

import logging
Expand All @@ -14,18 +14,21 @@
_LOG = logging.getLogger(__name__)


@app.route("/v1/datasources", methods=[POST, GET])
@app.route("/v1/datasources", methods=[controllers.POST, controllers.GET])
def operate_datasources():
token = get_bearer_token(request.headers.get("Authorization"))
try:
if request.method == POST:
result = services.datasource_service.create(
token=token, datasource=request.json
)
return jsonify(result.to_dict())
else:
results = services.datasource_service.get(token=token)
return jsonify({"items": [item.to_dict() for item in results]})
match request.method:
case controllers.POST:
result = services.datasource_service.create(
token=token, datasource=request.json
)
return jsonify(result.to_dict())
case controllers.GET:
results = services.datasource_service.get(token=token)
return jsonify({"items": [item.to_dict() for item in results]})
case _:
raise Exception("Should never be reached")
except util.ServiceException as se:
return jsonify(error_message(str(se))), 400
except util.UnauthorizedAccess:
Expand All @@ -39,49 +42,43 @@ def operate_datasources():
return jsonify(error_message("Server error")), 500


@app.route("/v1/datasources/<datasource_id>", methods=[GET, DELETE])
@app.route(
"/v1/datasources/<datasource_id>",
methods=[controllers.GET, controllers.DELETE, controllers.PUT],
)
def operate_datasource(datasource_id):
token = get_bearer_token(request.headers.get("Authorization"))
try:
if request.method == GET:
results = services.datasource_service.get(
token=token, datasource_id=datasource_id
)
if len(results) != 0:
return jsonify(results[0].to_dict())
else:
return (
jsonify(
error_message(
f"Datasource {datasource_id} not found", message_type="WARN"
)
),
404,
)
else:
services.datasource_service.delete(token=token, datasource_id=datasource_id)
return jsonify(success=True)
except util.ServiceException as se:
return jsonify(error_message(str(se))), 400
except util.UnauthorizedAccess:
return jsonify(error_message("Unauthorized access")), 401
except util.PermissionException:
return jsonify(error_message("Forbidden resource")), 403
except:
_LOG.exception("Error getting user")
return jsonify(error_message("Server error")), 500


@app.route("/v1/datasources/<datasource>/dataobjects/<dataobject>", methods=[GET])
def operate_dataobject(datasource, dataobject):
token = get_bearer_token(request.headers.get("Authorization"))
try:
if request.method == GET:
results = services.dataobject_service.get(
token=token, datasource=datasource, dataobject=dataobject
)
return jsonify(results.to_dict())

match request.method:
case controllers.GET:
results = services.datasource_service.get(
token=token, datasource_id=datasource_id
)
if len(results) != 0:
return jsonify(results[0].to_dict())
else:
return (
jsonify(
error_message(
f"Datasource {datasource_id} not found",
message_type="WARN",
)
),
404,
)
case controllers.DELETE:
services.datasource_service.delete(
token=token, datasource_id=datasource_id
)
return jsonify(success=True)
case controllers.PUT:
result = services.datasource_service.update(
token=token, datasource=request.json, datasource_id=datasource_id
)
return jsonify(result.to_dict())
case _:
raise Exception("Should never be reached really")
except util.ServiceException as se:
return jsonify(error_message(str(se))), 400
except util.UnauthorizedAccess:
Expand Down
101 changes: 83 additions & 18 deletions nesis/api/core/services/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ServiceException,
is_valid_resource_name,
has_valid_keys,
PermissionException,
)

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,21 +116,11 @@ def get(self, **kwargs):
session = DBSession()
try:

self._authorized(
session=session, token=kwargs.get("token"), action=Action.READ
)

# Get datasources this user is authorized to access
authorized_datasources: list[RoleAction] = services.authorized_resources(
self._session_service,
session=session,
token=kwargs.get("token"),
action=Action.READ,
resource_type=objects.ResourceType.DATASOURCES,
datasources = self._authorized_resources(
session=session, action=Action.READ, token=kwargs.get("token")
)

datasources = {ds.resource for ds in authorized_datasources}

session.expire_on_commit = False
query = session.query(Datasource)
if datasource_id:
Expand All @@ -143,6 +134,20 @@ def get(self, **kwargs):
if session:
session.close()

def _authorized_resources(self, token, session, action, resource=None):
authorized_datasources: list[RoleAction] = services.authorized_resources(
self._session_service,
session=session,
token=token,
action=action,
resource_type=objects.ResourceType.DATASOURCES,
)
datasources = {ds.resource for ds in authorized_datasources}
if resource and resource not in datasources:
raise PermissionException("Access to resource denied")

return {ds.resource for ds in authorized_datasources}

@staticmethod
def get_datasources(source_type: str = None) -> list[Datasource]:
session = DBSession()
Expand All @@ -165,19 +170,21 @@ def delete(self, **kwargs):
session = DBSession()
try:

self._authorized(
session=session, token=kwargs.get("token"), action=Action.DELETE
)

session.expire_on_commit = False

datasource = (
session.query(Datasource)
.filter(Datasource.uuid == datasource_id)
.first()
)

if datasource:
self._authorized_resources(
session=session,
action=Action.DELETE,
token=kwargs.get("token"),
resource=datasource.name,
)

session.delete(datasource)
session.commit()
except Exception as e:
Expand All @@ -188,4 +195,62 @@ def delete(self, **kwargs):
session.close()

def update(self, **kwargs):
raise NotImplementedError("Invalid operation on datasource")
"""
Update the datasource. The payload only contains fields that we intend to update. Any missing fields will be
ignored.
:param kwargs: datasource the datasource object as a dict
:param kwargs: id The datasource id
:return:
"""
datasource = kwargs["datasource"]
datasource_id = kwargs["datasource_id"]

session = DBSession()

try:
datasource_record: Datasource = (
session.query(Datasource)
.filter(Datasource.uuid == datasource_id)
.first()
)

if datasource is None:
raise ServiceException("Datasource not found")

self._authorized_resources(
session=session,
action=Action.UPDATE,
token=kwargs.get("token"),
resource=datasource_record.name,
)

session.expire_on_commit = False

source_type: str = datasource.get("type")
if source_type is not None:
try:
datasource_type = DatasourceType[source_type.upper()]
datasource_record.type = datasource_type
except Exception:
raise ServiceException("Invalid datasource type")

if datasource.get("connection"):
try:
connection = validators.validate_datasource_connection(datasource)
datasource_record.connection = connection
except ValueError as ve:
raise ServiceException(ve)

if not has_valid_keys(connection):
raise ServiceException("Missing connection details")

session.merge(datasource_record)
session.commit()
return datasource_record
except Exception as e:
session.rollback()
self._LOG.exception(f"Error when creating setting")
raise
finally:
if session:
session.close()
94 changes: 92 additions & 2 deletions nesis/api/tests/core/controllers/test_datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def client():
return cloud_app.test_client()


@pytest.fixture
def tc():
return ut.TestCase()


def get_admin_session(client):
admin_data = {
"name": "s3 documents",
Expand All @@ -37,7 +42,7 @@ def get_admin_session(client):
).json


def test_datasources(client):
def test_create_datasource_invalid_input(client, tc):
# Get the prediction
payload = {
"type": "minio",
Expand Down Expand Up @@ -94,7 +99,23 @@ def test_datasources(client):
)
assert 400 == response.status_code, response.json

assert 400 == response.status_code, response.json

def test_create_datasource(client, tc):
# Get the prediction
payload = {
"type": "minio",
"name": "finance6",
"connection": {
"user": "caikuodda",
"password": "some.password",
"host": "localhost",
"port": "5432",
"database": "initdb",
},
}

admin_session = get_admin_session(client=client)

response = client.post(
f"/v1/datasources",
headers=tests.get_header(token=admin_session["token"]),
Expand Down Expand Up @@ -134,3 +155,72 @@ def test_datasources(client):
headers=tests.get_header(token=admin_session["token"]),
)
assert 404 == response.status_code, response.json


def test_update_datasources(client, tc):
# Create a datasource
payload = {
"type": "minio",
"name": "finance6",
"connection": {
"user": "caikuodda",
"password": "some.password",
"host": "localhost",
"port": "5432",
"database": "initdb",
},
}

admin_session = get_admin_session(client=client)

response = client.post(
f"/v1/datasources",
headers=tests.get_header(token=admin_session["token"]),
data=json.dumps(payload),
)
assert 200 == response.status_code, response.json
assert response.json.get("connection") is not None
print(json.dumps(response.json["connection"]))

response = client.get(
"/v1/datasources", headers=tests.get_header(token=admin_session["token"])
)
assert 200 == response.status_code, response.json
print(response.json)
assert 1 == len(response.json["items"])

datasource_id = response.json["items"][0]["id"]

response = client.get(
f"/v1/datasources/{datasource_id}",
headers=tests.get_header(token=admin_session["token"]),
)
assert 200 == response.status_code, response.json

datasource = response.json

datasource["connection"] = {
"user": "root",
"password": "some.password",
"host": "some.other.host.tld",
"port": "3360",
"database": "initdb",
}

response = client.put(
f"/v1/datasources/{datasource_id}",
headers=tests.get_header(token=admin_session["token"]),
data=json.dumps(datasource),
)
assert 200 == response.status_code, response.json

response = client.get(
f"/v1/datasources/{datasource_id}",
headers=tests.get_header(token=admin_session["token"]),
)

# Datasource password is never emitted so we skip it
tc.assertDictEqual(
response.json["connection"],
{k: v for k, v in datasource["connection"].items() if k != "password"},
)
Loading

0 comments on commit 89c965d

Please sign in to comment.