Skip to content

Commit

Permalink
[DOM-65175] Add support for datasets as datasources (#153)
Browse files Browse the repository at this point in the history
[DOM-65175] Add support for datasets as datasources
  • Loading branch information
ddl-s-ramirezayuso authored Feb 20, 2025
1 parent 4e3de87 commit 98b0645
Show file tree
Hide file tree
Showing 31 changed files with 1,508 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ mypy:
check-safety:
poetry check
# TODO remove pip ignore flag when fixed
poetry run safety check --full-report -i 62044 -i 70612
poetry run safety check --full-report -i 62044 -i 70612 -i 73884
poetry run bandit -ll --recursive domino_data tests

.PHONY: lint
Expand Down
12 changes: 6 additions & 6 deletions datasource_api_client/api/datasource/get_datasource_by_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,27 @@ def _get_kwargs(
def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Union[DatasourceDto, ErrorResponse]]:
if response.status_code == HTTPStatus.OK:
if response.status_code == 200:
response_200 = DatasourceDto.from_dict(response.json())

return response_200
if response.status_code == HTTPStatus.BAD_REQUEST:
if response.status_code == 400:
response_400 = ErrorResponse.from_dict(response.json())

return response_400
if response.status_code == HTTPStatus.UNAUTHORIZED:
if response.status_code == 401:
response_401 = ErrorResponse.from_dict(response.json())

return response_401
if response.status_code == HTTPStatus.FORBIDDEN:
if response.status_code == 403:
response_403 = ErrorResponse.from_dict(response.json())

return response_403
if response.status_code == HTTPStatus.NOT_FOUND:
if response.status_code == 404:
response_404 = ErrorResponse.from_dict(response.json())

return response_404
if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
if response.status_code == 500:
response_500 = ErrorResponse.from_dict(response.json())

return response_500
Expand Down
6 changes: 3 additions & 3 deletions datasource_api_client/api/proxy/get_key_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def _get_kwargs(
def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Union[ProxyErrorResponse, str]]:
if response.status_code == HTTPStatus.OK:
if response.status_code == 200:
response_200 = cast(str, response.json())
return response_200
if response.status_code == HTTPStatus.BAD_REQUEST:
if response.status_code == 400:
response_400 = ProxyErrorResponse.from_dict(response.json())

return response_400
if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
if response.status_code == 500:
response_500 = ProxyErrorResponse.from_dict(response.json())

return response_500
Expand Down
6 changes: 3 additions & 3 deletions datasource_api_client/api/proxy/list_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ def _get_kwargs(
def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Union[List[str], ProxyErrorResponse]]:
if response.status_code == HTTPStatus.OK:
if response.status_code == 200:
response_200 = cast(List[str], response.json())

return response_200
if response.status_code == HTTPStatus.BAD_REQUEST:
if response.status_code == 400:
response_400 = ProxyErrorResponse.from_dict(response.json())

return response_400
if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
if response.status_code == 500:
response_500 = ProxyErrorResponse.from_dict(response.json())

return response_500
Expand Down
2 changes: 1 addition & 1 deletion datasource_api_client/api/proxy/log_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _get_kwargs(
def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Any]:
if response.status_code == HTTPStatus.OK:
if response.status_code == 200:
return None
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
Expand Down
4 changes: 2 additions & 2 deletions datasource_api_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def with_timeout(self, timeout: httpx.Timeout) -> "Client":
return evolve(self, timeout=timeout)

def set_httpx_client(self, client: httpx.Client) -> "Client":
"""Manually the underlying httpx.Client
"""Manually set the underlying httpx.Client
**NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
"""
Expand Down Expand Up @@ -209,7 +209,7 @@ def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient":
return evolve(self, timeout=timeout)

def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient":
"""Manually the underlying httpx.Client
"""Manually set the underlying httpx.Client
**NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
"""
Expand Down
1 change: 1 addition & 0 deletions datasource_api_client/models/datasource_dto_auth_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class DatasourceDtoAuthType(str, Enum):
BASIC = "Basic"
CLIENTIDSECRET = "ClientIdSecret"
GCPBASIC = "GCPBasic"
NOAUTH = "NoAuth"
OAUTH = "OAuth"
OAUTHTOKEN = "OAuthToken"
PERSONALTOKEN = "PersonalToken"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class DatasourceDtoDataSourceType(str, Enum):
BIGQUERYCONFIG = "BigQueryConfig"
CLICKHOUSECONFIG = "ClickHouseConfig"
DATABRICKSCONFIG = "DatabricksConfig"
DATASETCONFIG = "DatasetConfig"
DB2CONFIG = "DB2Config"
DRUIDCONFIG = "DruidConfig"
GCSCONFIG = "GCSConfig"
Expand Down
1 change: 1 addition & 0 deletions domino_data/_feature_store/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __attrs_post_init__(self):
api_key=self.api_key,
token_file=self.token_file,
token_url=token_url,
token=None,
headers={"Accept": "application/json"},
)

Expand Down
12 changes: 10 additions & 2 deletions domino_data/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ class AuthenticatedClient(Client):
api_key: Optional[str] = attr.ib()
token_file: Optional[str] = attr.ib()
token_url: Optional[str] = attr.ib()
token: Optional[str] = attr.ib()

def __attrs_post_init__(self):
if not (self.api_key or self.token_file or self.token_url):
if not (self.api_key or self.token_file or self.token_url or self.token):
raise Exception(
"One of two authentication methods must be supplied (API Key or JWT Location)" # noqa
)

def _get_auth_headers(self) -> Dict[str, str]:
"""Get auth headers with either JWT or API Key."""
if self.token is not None:
return {"Authorization": f"Bearer {self.token}"}
if self.token_url is not None:
try:
jwt = get_jwt_token(self.token_url)
Expand Down Expand Up @@ -91,6 +94,8 @@ def _get_auth_headers(self) -> Dict[str, str]:
headers["X-Domino-Client-Source"] = self.client_source
if self.run_id:
headers["X-Domino-Run-Id"] = self.run_id
if self.token is not None:
headers["Authorization"] = f"Bearer {self.token}"

if self.token_url is not None:
try:
Expand All @@ -116,16 +121,19 @@ class AuthMiddlewareFactory(flight.ClientMiddlewareFactory):
api_key: Optional[str] = attr.ib()
token_file: Optional[str] = attr.ib()
token_url: Optional[str] = attr.ib()
token: Optional[str] = attr.ib()

def __attrs_post_init__(self):
if not (self.api_key or self.token_file or self.token_url):
if not (self.api_key or self.token_file or self.token_url or self.token):
raise Exception(
"One of two authentication methods must be supplied (API Key or JWT Location)" # noqa
)

def start_call(self, _):
"""Called at the start of an RPC."""
jwt = None
if self.token is not None:
return {"Authorization": f"Bearer {self.token}"}

if self.token_url is not None:
try:
Expand Down
15 changes: 14 additions & 1 deletion domino_data/configuration_gen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Code generated by gen.py; DO NOT EDIT.
This file was generated by robots at
2024-01-18 15:51:53.230967"""
2025-02-20 13:30:06.271414"""

from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -66,6 +66,7 @@ class ConfigElem(Enum):
CATALOG = "catalog"
CLUSTER = "cluster"
DATABASE = "database"
DATASETID = "datasetID"
EXTRAPROPERTIES = "extraProperties"
HOST = "host"
HTTPPATH = "httpPath"
Expand All @@ -79,6 +80,8 @@ class ConfigElem(Enum):
WAREHOUSE = "warehouse"
CATALOGCODE = "catalogCode"
ENVIRONMENT = "environment"
SNAPSHOTID = "snapshotID"
SSLENABLED = "sslEnabled"


class CredElem(Enum):
Expand Down Expand Up @@ -162,6 +165,14 @@ class DatabricksConfig(Config):
personal_access_token: Optional[str] = _cred(elem=CredElem.PERSONALACCESSTOKEN)


@attr.s(auto_attribs=True)
class DatasetConfig(Config):
"""DatasetConfig datasource configuration."""

snapshot_id: Optional[str] = _config(elem=ConfigElem.SNAPSHOTID)
subfolder: Optional[str] = _config(elem=ConfigElem.SUBFOLDER)


@attr.s(auto_attribs=True)
class DB2Config(Config):
"""DB2Config datasource configuration."""
Expand Down Expand Up @@ -411,6 +422,7 @@ class VerticaConfig(Config):
BigQueryConfig,
ClickHouseConfig,
DatabricksConfig,
DatasetConfig,
DB2Config,
DruidConfig,
GCSConfig,
Expand Down Expand Up @@ -445,6 +457,7 @@ class VerticaConfig(Config):
DatasourceDtoDataSourceType.BIGQUERYCONFIG: "TabularDatasource",
DatasourceDtoDataSourceType.CLICKHOUSECONFIG: "TabularDatasource",
DatasourceDtoDataSourceType.DATABRICKSCONFIG: "TabularDatasource",
DatasourceDtoDataSourceType.DATASETCONFIG: "ObjectStoreDatasource",
DatasourceDtoDataSourceType.DB2CONFIG: "TabularDatasource",
DatasourceDtoDataSourceType.DRUIDCONFIG: "TabularDatasource",
DatasourceDtoDataSourceType.GCSCONFIG: "ObjectStoreDatasource",
Expand Down
4 changes: 4 additions & 0 deletions domino_data/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ class DataSourceClient:
api_key: Optional[str] = attr.ib(factory=lambda: os.getenv(DOMINO_USER_API_KEY))
token_file: Optional[str] = attr.ib(factory=lambda: os.getenv(DOMINO_TOKEN_FILE))
token_url: Optional[str] = attr.ib(factory=lambda: os.getenv(DOMINO_API_PROXY))
token: Optional[str] = attr.ib(default=None)

def __attrs_post_init__(self):
flight_host = os.getenv(DOMINO_DATASOURCE_PROXY_FLIGHT_HOST)
Expand All @@ -610,6 +611,7 @@ def __attrs_post_init__(self):
run_id=run_id,
token_file=self.token_file,
token_url=self.token_url,
token=self.token,
timeout=httpx.Timeout(5.0),
verify_ssl=True,
)
Expand All @@ -618,6 +620,7 @@ def __attrs_post_init__(self):
api_key=self.api_key,
token_file=self.token_file,
token_url=self.token_url,
token=self.token,
headers=ACCEPT_HEADERS,
timeout=httpx.Timeout(20.0),
verify_ssl=True,
Expand All @@ -634,6 +637,7 @@ def _set_proxy(self):
self.api_key,
self.token_file,
self.token_url,
self.token,
),
MetaMiddlewareFactory(client_source=client_source, run_id=run_id),
],
Expand Down
Loading

0 comments on commit 98b0645

Please sign in to comment.