Skip to content

Commit

Permalink
[components] Remove translator from SlingReplicationCollectionComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 8, 2025
1 parent 22c7cdb commit 65347e4
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 113 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC
from typing import AbstractSet, Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union
from typing import Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union

import dagster._check as check
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.asset_spec import AssetSpec, map_asset_specs
Expand All @@ -10,7 +11,7 @@
)
from dagster._core.definitions.definitions_class import Definitions
from dagster._record import replace
from pydantic import BaseModel, Field
from pydantic import BaseModel

from dagster_components.core.component_rendering import (
RenderedModel,
Expand Down Expand Up @@ -50,36 +51,45 @@ class AssetAttributesModel(RenderedModel):
] = None


class AssetSpecProcessor(ABC, BaseModel):
class AssetSpecTransform(ABC, BaseModel):
target: str = "*"
operation: Literal["merge", "replace"] = "merge"
attributes: AssetAttributesModel

class Config:
arbitrary_types_allowed = True

def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: ...

def apply_to_spec(
self,
spec: AssetSpec,
value_renderer: TemplatedValueRenderer,
target_keys: AbstractSet[AssetKey],
) -> AssetSpec:
if spec.key not in target_keys:
return spec

# add the original spec to the context and resolve values
return self._apply_to_spec(
spec, self.attributes.render_properties(value_renderer.with_context(asset=spec))
)
attributes = self.attributes.render_properties(value_renderer.with_context(asset=spec))

if self.operation == "merge":
mergeable_attributes = {"metadata", "tags"}
merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes}
replace_attributes = {
k: v for k, v in attributes.items() if k not in mergeable_attributes
}
return spec.merge_attributes(**merge_attributes).replace_attributes(
**replace_attributes
)
elif self.operation == "replace":
return spec.replace_attributes(**attributes)
else:
check.failed(f"Unsupported operation: {self.operation}")

def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> Definitions:
target_selection = AssetSelection.from_string(self.target, include_sources=True)
target_keys = target_selection.resolve(defs.get_asset_graph())

mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))]
mapped_assets = map_asset_specs(
lambda spec: self.apply_to_spec(spec, value_renderer, target_keys),
lambda spec: self.apply_to_spec(spec, value_renderer)
if spec.key in target_keys
else spec,
mappable,
)

Expand All @@ -88,27 +98,3 @@ def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> De
*[d for d in defs.assets or [] if not isinstance(d, (AssetsDefinition, AssetSpec))],
]
return replace(defs, assets=assets)


class MergeAttributes(AssetSpecProcessor):
    # Processor variant selected when ``operation`` is "merge" (also the default,
    # so the field may be omitted). "metadata" and "tags" are handed to
    # ``AssetSpec.merge_attributes``; every other attribute is handed to
    # ``AssetSpec.replace_attributes``.
    operation: Literal["merge"] = "merge"

    def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
        """Return ``spec`` updated with the already-rendered ``attributes``.

        Args:
            spec: The asset spec to update.
            attributes: Rendered attribute values keyed by attribute name.

        Returns:
            The result of ``merge_attributes`` (for metadata/tags) followed by
            ``replace_attributes`` (for everything else).
        """
        to_merge = {}
        to_replace = {}
        for attr_name, attr_value in attributes.items():
            # Only these two attributes are routed through merge_attributes.
            if attr_name in ("metadata", "tags"):
                to_merge[attr_name] = attr_value
            else:
                to_replace[attr_name] = attr_value

        merged_spec = spec.merge_attributes(**to_merge)
        return merged_spec.replace_attributes(**to_replace)


class ReplaceAttributes(AssetSpecProcessor):
    # Processor variant selected only when ``operation`` is given explicitly as
    # "replace" (the field has no default here).
    operation: Literal["replace"]

    def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
        """Return ``spec`` with every rendered attribute passed to ``replace_attributes``.

        Args:
            spec: The asset spec to update.
            attributes: Rendered attribute values keyed by attribute name.
        """
        replaced_spec = spec.replace_attributes(**attributes)
        return replaced_spec


# Schema type for a list of asset-attribute processors. Each element is either a
# MergeAttributes or a ReplaceAttributes; the union is declared with pydantic's
# "left_to_right" mode, i.e. MergeAttributes is listed (and presumably tried)
# first -- NOTE(review): confirm against the pydantic union-mode documentation.
AssetAttributes = Sequence[
Annotated[Union[MergeAttributes, ReplaceAttributes], Field(union_mode="left_to_right")]
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from dagster_components import Component, ComponentLoadContext
from dagster_components.core.component import TemplatedValueRenderer, component_type
from dagster_components.core.dsl_schema import (
AssetAttributes,
AssetAttributesModel,
AssetSpecProcessor,
AssetSpecTransform,
OpSpecBaseModel,
)
from dagster_components.lib.dbt_project.generator import DbtProjectComponentGenerator
Expand All @@ -22,7 +21,7 @@ class DbtProjectParams(BaseModel):
dbt: DbtCliResource
op: Optional[OpSpecBaseModel] = None
translator: Optional[AssetAttributesModel] = None
asset_attributes: Optional[AssetAttributes] = None
asset_transforms: Optional[Sequence[AssetSpecTransform]] = None


class DbtProjectComponentTranslator(DagsterDbtTranslator):
Expand Down Expand Up @@ -70,12 +69,12 @@ def __init__(
dbt_resource: DbtCliResource,
op_spec: Optional[OpSpecBaseModel],
dbt_translator: Optional[DagsterDbtTranslator],
asset_processors: Sequence[AssetSpecProcessor],
asset_transforms: Sequence[AssetSpecTransform],
):
self.dbt_resource = dbt_resource
self.op_spec = op_spec
self.dbt_translator = dbt_translator
self.asset_processors = asset_processors
self.asset_transforms = asset_transforms

@classmethod
def get_generator(cls) -> "DbtProjectComponentGenerator":
Expand All @@ -96,7 +95,7 @@ def load(cls, context: ComponentLoadContext) -> Self:
params=loaded_params.translator,
value_renderer=context.templated_value_renderer,
),
asset_processors=loaded_params.asset_attributes or [],
asset_transforms=loaded_params.asset_transforms or [],
)

def build_defs(self, context: ComponentLoadContext) -> Definitions:
Expand All @@ -114,7 +113,7 @@ def _fn(context: AssetExecutionContext):
yield from self.execute(context=context, dbt=self.dbt_resource)

defs = Definitions(assets=[_fn])
for transform in self.asset_processors:
for transform in self.asset_transforms:
defs = transform.apply(defs, context.templated_value_renderer)
return defs

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from pathlib import Path
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
from typing import Iterator, Optional, Sequence, Union

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.events import AssetMaterialization
from dagster._core.definitions.result import MaterializeResult
Expand All @@ -13,66 +11,25 @@
from typing_extensions import Self

from dagster_components import Component, ComponentLoadContext
from dagster_components.core.component import TemplatedValueRenderer, component_type
from dagster_components.core.component import component_type
from dagster_components.core.dsl_schema import (
AssetAttributes,
AssetAttributesModel,
AssetSpecProcessor,
AssetSpecTransform,
OpSpecBaseModel,
)
from dagster_components.utils import get_wrapped_translator_class


class SlingReplicationParams(BaseModel):
path: str
op: Optional[OpSpecBaseModel] = None
translator: Optional[AssetAttributesModel] = None
asset_attributes: Optional[AssetAttributesModel] = None


class SlingReplicationCollectionParams(BaseModel):
sling: Optional[SlingResource] = None
replications: Sequence[SlingReplicationParams]
asset_attributes: Optional[AssetAttributes] = None


class SlingReplicationTranslator(DagsterSlingTranslator):
    """Sling translator whose per-stream attributes can be overridden by templated params.

    Each ``get_*`` method renders ``params`` with the stream definition in scope and
    returns the rendered value for its attribute; when that value is ``None`` the
    parent translator's implementation is used instead.
    """

    def __init__(
        self,
        *,
        params: Optional[AssetAttributesModel],
        value_renderer: TemplatedValueRenderer,
    ):
        # A falsy ``params`` (e.g. None) is swapped for an empty attributes model.
        self.params = params if params else AssetAttributesModel()
        self.value_renderer = value_renderer

    def _get_rendered_attribute(
        self, attribute: str, stream_definition: Mapping[str, Any], default_method
    ) -> Any:
        """Return the rendered override for ``attribute``, or ``default_method``'s result.

        Args:
            attribute: Name of the attribute to look up in the rendered params.
            stream_definition: The stream definition exposed to templates under
                the name ``stream_definition``.
            default_method: Callable invoked with ``stream_definition`` when the
                rendered params yield ``None`` for ``attribute``.
        """
        scoped_renderer = self.value_renderer.with_context(stream_definition=stream_definition)
        rendered_params = self.params.render_properties(scoped_renderer)
        override = rendered_params.get(attribute)
        if override is None:
            return default_method(stream_definition)
        return override

    def get_asset_key(self, stream_definition: Mapping[str, Any]) -> AssetKey:
        parent_impl = super().get_asset_key
        return self._get_rendered_attribute("key", stream_definition, parent_impl)

    def get_group_name(self, stream_definition: Mapping[str, Any]) -> Optional[str]:
        parent_impl = super().get_group_name
        return self._get_rendered_attribute("group_name", stream_definition, parent_impl)

    def get_tags(self, stream_definition: Mapping[str, Any]) -> Mapping[str, str]:
        parent_impl = super().get_tags
        return self._get_rendered_attribute("tags", stream_definition, parent_impl)

    def get_metadata(self, stream_definition: Mapping[str, Any]) -> Mapping[str, Any]:
        parent_impl = super().get_metadata
        return self._get_rendered_attribute("metadata", stream_definition, parent_impl)

    def get_auto_materialize_policy(
        self, stream_definition: Mapping[str, Any]
    ) -> Optional[AutoMaterializePolicy]:
        parent_impl = super().get_auto_materialize_policy
        return self._get_rendered_attribute(
            "auto_materialize_policy", stream_definition, parent_impl
        )
asset_transforms: Optional[Sequence[AssetSpecTransform]] = None


@component_type(name="sling_replication_collection")
Expand All @@ -82,12 +39,12 @@ def __init__(
dirpath: Path,
resource: SlingResource,
sling_replications: Sequence[SlingReplicationParams],
asset_attributes: Sequence[AssetSpecProcessor],
asset_transforms: Sequence[AssetSpecTransform],
):
self.dirpath = dirpath
self.resource = resource
self.sling_replications = sling_replications
self.asset_attributes = asset_attributes
self.asset_transforms = asset_transforms

@classmethod
def get_params_schema_type(cls):
Expand All @@ -100,18 +57,22 @@ def load(cls, context: ComponentLoadContext) -> Self:
dirpath=context.path,
resource=loaded_params.sling or SlingResource(),
sling_replications=loaded_params.replications,
asset_attributes=loaded_params.asset_attributes or [],
asset_transforms=loaded_params.asset_transforms or [],
)

def build_replication_asset(
self, context: ComponentLoadContext, replication: SlingReplicationParams
) -> AssetsDefinition:
translator_cls = get_wrapped_translator_class(DagsterSlingTranslator)

@sling_assets(
name=replication.op.name if replication.op else Path(replication.path).stem,
op_tags=replication.op.tags if replication.op else {},
replication_config=self.dirpath / replication.path,
dagster_sling_translator=SlingReplicationTranslator(
params=replication.translator,
dagster_sling_translator=translator_cls(
obj_name="stream_definition",
base_translator=DagsterSlingTranslator(),
asset_attributes=replication.asset_attributes or AssetAttributesModel(),
value_renderer=context.templated_value_renderer,
),
)
Expand All @@ -132,6 +93,6 @@ def build_defs(self, context: ComponentLoadContext) -> Definitions:
for replication in self.sling_replications
],
)
for transform in self.asset_attributes:
for transform in self.asset_transforms:
defs = transform.apply(defs, context.templated_value_renderer)
return defs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import sys
from pathlib import Path
from typing import Any, Mapping, Optional, Type

from dagster._core.definitions.asset_key import AssetKey
from dagster_dbt.dagster_dbt_translator import AutomationCondition

from dagster_components.core.component_rendering import TemplatedValueRenderer
from dagster_components.core.dsl_schema import AssetAttributesModel

CLI_BUILTIN_COMPONENT_LIB_KEY = "builtin_component_lib"

Expand All @@ -12,3 +19,46 @@ def ensure_dagster_components_tests_import() -> None:
dagster_components_package_root / "dagster_components_tests"
).exists(), "Could not find dagster_components_tests where expected"
sys.path.append(dagster_components_package_root.as_posix())


def get_wrapped_translator_class(translator_type: Type):
    """Temporary hack to allow wrapping of many methods of a given translator class.

    Builds a subclass of ``translator_type`` whose ``get_*`` methods first look for a
    templated override in ``asset_attributes`` and otherwise defer to the wrapped
    ``base_translator`` instance.

    Args:
        translator_type: The translator class to subclass, so that instances of the
            returned class still satisfy ``isinstance`` checks against it.

    Returns:
        The ``WrappedTranslator`` class itself (not an instance).
    """

    class WrappedTranslator(translator_type):
        """Translator applying templated attribute overrides on top of a base translator."""

        def __init__(
            self,
            *,
            obj_name: str,
            base_translator,
            asset_attributes: AssetAttributesModel,
            value_renderer: TemplatedValueRenderer,
        ):
            # NOTE: ``translator_type.__init__`` is not invoked here, so any state the
            # parent class would set up on ``self`` is absent. Fallback behaviour is
            # therefore delegated to ``base_translator`` (a fully constructed instance)
            # rather than to ``super()``.
            self.obj_name = obj_name
            self.base_translator = base_translator
            self.asset_attributes = asset_attributes
            self.value_renderer = value_renderer

        def _get_rendered_attribute(self, attribute: str, obj: Any, default_method) -> Any:
            """Return the rendered override for ``attribute`` or ``default_method(obj)``.

            ``obj`` is exposed to the templates under the name ``self.obj_name``.
            A key that is present in the rendered attributes wins even when its
            value is ``None``; the default is used only when the key is missing.
            """
            renderer = self.value_renderer.with_context(**{self.obj_name: obj})
            rendered_attributes = self.asset_attributes.render_properties(renderer)
            if attribute in rendered_attributes:
                return rendered_attributes[attribute]
            return default_method(obj)

        def get_asset_key(self, obj: Any) -> AssetKey:
            return self._get_rendered_attribute("key", obj, self.base_translator.get_asset_key)

        def get_group_name(self, obj: Any) -> Optional[str]:
            return self._get_rendered_attribute(
                "group_name", obj, self.base_translator.get_group_name
            )

        def get_tags(self, obj: Any) -> Mapping[str, str]:
            return self._get_rendered_attribute("tags", obj, self.base_translator.get_tags)

        def get_automation_condition(self, obj: Any) -> Optional[AutomationCondition]:
            return self._get_rendered_attribute(
                "automation_condition", obj, self.base_translator.get_automation_condition
            )

    return WrappedTranslator
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ type: dagster_components.sling_replication_collection
params:
replications:
- path: ./replication.yaml
asset_attributes:
key: "foo/{{ stream_definition.config.meta.dagster.asset_key }}"
sling:
connections:
- name: DUCKDB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_load_from_path(sling_path: Path) -> None:
assert len(components) == 1
assert get_asset_keys(components[0]) == {
AssetKey("input_csv"),
AssetKey("input_duckdb"),
AssetKey(["foo", "input_duckdb"]),
}

assert_assets(components[0], 2)
Expand Down
Loading

0 comments on commit 65347e4

Please sign in to comment.