diff --git a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py index 1507f6bc3f5cf..2f726296b2bb4 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py @@ -1,7 +1,8 @@ from abc import ABC from collections.abc import Mapping, Sequence -from typing import AbstractSet, Annotated, Any, Literal, Optional, Union # noqa: UP035 +from typing import Annotated, Any, Literal, Optional, Union +import dagster._check as check from dagster._core.definitions.asset_key import AssetKey from dagster._core.definitions.asset_selection import AssetSelection from dagster._core.definitions.asset_spec import AssetSpec, map_asset_specs @@ -11,7 +12,7 @@ ) from dagster._core.definitions.definitions_class import Definitions from dagster._record import replace -from pydantic import BaseModel, Field +from pydantic import BaseModel from dagster_components.core.component_rendering import ( RenderedModel, @@ -51,28 +52,35 @@ class AssetAttributesModel(RenderedModel): ] = None -class AssetSpecProcessor(ABC, BaseModel): +class AssetSpecTransform(ABC, BaseModel): target: str = "*" + operation: Literal["merge", "replace"] = "merge" attributes: AssetAttributesModel class Config: arbitrary_types_allowed = True - def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: ... - def apply_to_spec( self, spec: AssetSpec, value_renderer: TemplatedValueRenderer, - target_keys: AbstractSet[AssetKey], ) -> AssetSpec: - if spec.key not in target_keys: - return spec - # add the original spec to the context and resolve values - return self._apply_to_spec( - spec, self.attributes.render_properties(value_renderer.with_context(asset=spec)) - ) + attributes = self.attributes.render_properties(value_renderer.with_context(asset=spec)) + + if self.operation == "merge": + mergeable_attributes = {"metadata", "tags"} + merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes} + replace_attributes = { + k: v for k, v in attributes.items() if k not in mergeable_attributes + } + return spec.merge_attributes(**merge_attributes).replace_attributes( + **replace_attributes + ) + elif self.operation == "replace": + return spec.replace_attributes(**attributes) + else: + check.failed(f"Unsupported operation: {self.operation}") def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> Definitions: target_selection = AssetSelection.from_string(self.target, include_sources=True) @@ -80,7 +88,9 @@ def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> De mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))] mapped_assets = map_asset_specs( - lambda spec: self.apply_to_spec(spec, value_renderer, target_keys), + lambda spec: self.apply_to_spec(spec, value_renderer) + if spec.key in target_keys + else spec, mappable, ) @@ -89,27 +99,3 @@ def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> De *[d for d in defs.assets or [] if not isinstance(d, (AssetsDefinition, AssetSpec))], ] return replace(defs, assets=assets) - - -class MergeAttributes(AssetSpecProcessor): - # default operation is "merge" - operation: Literal["merge"] = "merge" - - def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: - mergeable_attributes = {"metadata", "tags"} - merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes} - replace_attributes = {k: v for k, v in attributes.items() if k not in mergeable_attributes} - return spec.merge_attributes(**merge_attributes).replace_attributes(**replace_attributes) - - -class ReplaceAttributes(AssetSpecProcessor): - # operation must be set explicitly - operation: Literal["replace"] - - def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: - return spec.replace_attributes(**attributes) - - -AssetAttributes = Sequence[ - Annotated[Union[MergeAttributes, ReplaceAttributes], Field(union_mode="left_to_right")] -] diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py index 29819c20b91d5..917d6d1b74291 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py @@ -11,9 +11,8 @@ from dagster_components import Component, ComponentLoadContext from dagster_components.core.component import TemplatedValueRenderer, component_type from dagster_components.core.dsl_schema import ( - AssetAttributes, AssetAttributesModel, - AssetSpecProcessor, + AssetSpecTransform, OpSpecBaseModel, ) from dagster_components.lib.dbt_project.generator import DbtProjectComponentGenerator @@ -23,7 +22,7 @@ class DbtProjectParams(BaseModel): dbt: DbtCliResource op: Optional[OpSpecBaseModel] = None translator: Optional[AssetAttributesModel] = None - asset_attributes: Optional[AssetAttributes] = None + transforms: Optional[Sequence[AssetSpecTransform]] = None class DbtProjectComponentTranslator(DagsterDbtTranslator): @@ -71,12 +70,12 @@ def __init__( dbt_resource: DbtCliResource, op_spec: Optional[OpSpecBaseModel], dbt_translator: Optional[DagsterDbtTranslator], - asset_processors: Sequence[AssetSpecProcessor], + transforms: Sequence[AssetSpecTransform], ): self.dbt_resource = dbt_resource self.op_spec = op_spec self.dbt_translator = dbt_translator - self.asset_processors = asset_processors + self.transforms = transforms @classmethod def get_generator(cls) -> "DbtProjectComponentGenerator": @@ -97,7 +96,7 @@ def load(cls, context: ComponentLoadContext) -> Self: params=loaded_params.translator, value_renderer=context.templated_value_renderer, ), - asset_processors=loaded_params.asset_attributes or [], + transforms=loaded_params.transforms or [], ) def build_defs(self, context: ComponentLoadContext) -> Definitions: @@ -115,7 +114,7 @@ def _fn(context: AssetExecutionContext): yield from self.execute(context=context, dbt=self.dbt_resource) defs = Definitions(assets=[_fn]) - for transform in self.asset_processors: + for transform in self.transforms: defs = transform.apply(defs, context.templated_value_renderer) return defs diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py b/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py index d7ad0b1a4099e..0c672d265e0ad 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py @@ -1,10 +1,8 @@ -from collections.abc import Iterator, Mapping, Sequence +from collections.abc import Iterator, Sequence from pathlib import Path -from typing import Any, Optional, Union +from typing import Optional, Union -from dagster._core.definitions.asset_key import AssetKey from dagster._core.definitions.assets import AssetsDefinition -from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy from dagster._core.definitions.definitions_class import Definitions from dagster._core.definitions.events import AssetMaterialization from dagster._core.definitions.result import MaterializeResult @@ -14,67 +12,26 @@ from typing_extensions import Self from dagster_components import Component, ComponentLoadContext -from dagster_components.core.component import TemplatedValueRenderer, component_type +from dagster_components.core.component import component_type from dagster_components.core.component_generator import ComponentGenerator from dagster_components.core.dsl_schema import ( - AssetAttributes, AssetAttributesModel, - AssetSpecProcessor, + AssetSpecTransform, OpSpecBaseModel, ) +from dagster_components.utils import get_wrapped_translator_class class SlingReplicationParams(BaseModel): path: str op: Optional[OpSpecBaseModel] = None - translator: Optional[AssetAttributesModel] = None + asset_attributes: Optional[AssetAttributesModel] = None class SlingReplicationCollectionParams(BaseModel): sling: Optional[SlingResource] = None replications: Sequence[SlingReplicationParams] - asset_attributes: Optional[AssetAttributes] = None - - -class SlingReplicationTranslator(DagsterSlingTranslator): - def __init__( - self, - *, - params: Optional[AssetAttributesModel], - value_renderer: TemplatedValueRenderer, - ): - self.params = params or AssetAttributesModel() - self.value_renderer = value_renderer - - def _get_rendered_attribute( - self, attribute: str, stream_definition: Mapping[str, Any], default_method - ) -> Any: - renderer = self.value_renderer.with_context(stream_definition=stream_definition) - rendered_attribute = self.params.render_properties(renderer).get(attribute) - return ( - rendered_attribute - if rendered_attribute is not None - else default_method(stream_definition) - ) - - def get_asset_key(self, stream_definition: Mapping[str, Any]) -> AssetKey: - return self._get_rendered_attribute("key", stream_definition, super().get_asset_key) - - def get_group_name(self, stream_definition: Mapping[str, Any]) -> Optional[str]: - return self._get_rendered_attribute("group_name", stream_definition, super().get_group_name) - - def get_tags(self, stream_definition: Mapping[str, Any]) -> Mapping[str, str]: - return self._get_rendered_attribute("tags", stream_definition, super().get_tags) - - def get_metadata(self, stream_definition: Mapping[str, Any]) -> Mapping[str, Any]: - return self._get_rendered_attribute("metadata", stream_definition, super().get_metadata) - - def get_auto_materialize_policy( - self, stream_definition: Mapping[str, Any] - ) -> Optional[AutoMaterializePolicy]: - return self._get_rendered_attribute( - "auto_materialize_policy", stream_definition, super().get_auto_materialize_policy - ) + transforms: Optional[Sequence[AssetSpecTransform]] = None @component_type(name="sling_replication_collection") @@ -84,12 +41,12 @@ def __init__( dirpath: Path, resource: SlingResource, sling_replications: Sequence[SlingReplicationParams], - asset_attributes: Sequence[AssetSpecProcessor], + transforms: Sequence[AssetSpecTransform], ): self.dirpath = dirpath self.resource = resource self.sling_replications = sling_replications - self.asset_attributes = asset_attributes + self.transforms = transforms @classmethod def get_generator(cls) -> ComponentGenerator: @@ -110,18 +67,22 @@ def load(cls, context: ComponentLoadContext) -> Self: dirpath=context.path, resource=loaded_params.sling or SlingResource(), sling_replications=loaded_params.replications, - asset_attributes=loaded_params.asset_attributes or [], + transforms=loaded_params.transforms or [], ) def build_replication_asset( self, context: ComponentLoadContext, replication: SlingReplicationParams ) -> AssetsDefinition: + translator_cls = get_wrapped_translator_class(DagsterSlingTranslator) + @sling_assets( name=replication.op.name if replication.op else Path(replication.path).stem, op_tags=replication.op.tags if replication.op else {}, replication_config=self.dirpath / replication.path, - dagster_sling_translator=SlingReplicationTranslator( - params=replication.translator, + dagster_sling_translator=translator_cls( + obj_name="stream_definition", + base_translator=DagsterSlingTranslator(), + asset_attributes=replication.asset_attributes or AssetAttributesModel(), value_renderer=context.templated_value_renderer, ), ) @@ -142,6 +103,6 @@ def build_defs(self, context: ComponentLoadContext) -> Definitions: for replication in self.sling_replications ], ) - for transform in self.asset_attributes: + for transform in self.transforms: defs = transform.apply(defs, context.templated_value_renderer) return defs diff --git a/python_modules/libraries/dagster-components/dagster_components/utils.py b/python_modules/libraries/dagster-components/dagster_components/utils.py index 474f3f5470d01..a76bb24c38554 100644 --- a/python_modules/libraries/dagster-components/dagster_components/utils.py +++ b/python_modules/libraries/dagster-components/dagster_components/utils.py @@ -1,11 +1,21 @@ import importlib.util import sys -from collections.abc import Iterator +from collections.abc import Iterator, Mapping from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path +from typing import Any, Optional +from dagster._core.definitions.asset_key import AssetKey +from dagster._core.definitions.asset_spec import AssetSpec +from dagster._core.definitions.declarative_automation.automation_condition import ( + AutomationCondition, +) from dagster._core.errors import DagsterError +from dagster_components.core.component_rendering import TemplatedValueRenderer +from dagster_components.core.dsl_schema import AssetAttributesModel + CLI_BUILTIN_COMPONENT_LIB_KEY = "builtin_component_lib" @@ -40,3 +50,70 @@ def get_path_for_package(package_name: str) -> str: if not submodule_search_locations: raise DagsterError(f"Package does not have any locations for submodules: {package_name}") return submodule_search_locations[0] + + +@dataclass +class ResolvingInfo: + obj_name: str + asset_attributes: AssetAttributesModel + value_renderer: TemplatedValueRenderer + + def get_rendered_attribute(self, attribute: str, obj: Any, default_method) -> Any: + renderer = self.value_renderer.with_context(**{self.obj_name: obj}) + rendered_attributes = self.asset_attributes.render_properties(renderer) + return ( + rendered_attributes[attribute] + if attribute in rendered_attributes + else default_method(obj) + ) + + def get_asset_spec(self, base_spec: AssetSpec, context: Mapping[str, Any]) -> AssetSpec: + """Returns an AssetSpec that combines the base spec with attributes resolved using the provided context. + + Usage: + + ```python + class WrappedDagsterXTranslator(DagsterXTranslator): + def __init__(self, *, base_translator, resolving_info: ResolvingInfo): + self.base_translator = base_translator + self.resolving_info = resolving_info + + def get_asset_spec(self, base_spec: AssetSpec, x_params: Any) -> AssetSpec: + return self.resolving_info.get_asset_spec( + base_spec, {"x_params": x_params} + ) + + ``` + """ + resolver = self.value_renderer.with_context(**context) + resolved_attributes = self.asset_attributes.render_properties(resolver) + return base_spec.replace_attributes(**resolved_attributes) + + +def get_wrapped_translator_class(translator_type: type): + """Temporary hack to allow wrapping of many methods of a given translator class. Will be removed + once all translators implement `get_asset_spec`. + """ + + class WrappedTranslator(translator_type): + def __init__(self, *, base_translator, resolving_info: ResolvingInfo): + self.base_translator = base_translator + self.resolving_info = resolving_info + + def get_asset_key(self, obj: Any) -> AssetKey: + return self.resolving_info.get_rendered_attribute("key", obj, super().get_asset_key) + + def get_group_name(self, obj: Any) -> Optional[str]: + return self.resolving_info.get_rendered_attribute( + "group_name", obj, super().get_group_name + ) + + def get_tags(self, obj: Any) -> Mapping[str, str]: + return self._get_rendered_attribute("tags", obj, super().get_tags) + + def get_automation_condition(self, obj: Any) -> Optional[AutomationCondition]: + return self._get_rendered_attribute( + "automation_condition", obj, super().get_automation_condition + ) + + return WrappedTranslator diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/dbt_project_location/components/jaffle_shop_dbt/component.yaml b/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/dbt_project_location/components/jaffle_shop_dbt/component.yaml index 8de868b9ec71b..e6be110095d33 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/dbt_project_location/components/jaffle_shop_dbt/component.yaml +++ b/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/dbt_project_location/components/jaffle_shop_dbt/component.yaml @@ -4,7 +4,7 @@ params: dbt: project_dir: jaffle_shop - asset_attributes: + asset_transforms: - attributes: tags: foo: bar diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/sling_location/components/ingest/component.yaml b/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/sling_location/components/ingest/component.yaml index 644c275eadb3f..b725f0927bf72 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/sling_location/components/ingest/component.yaml +++ b/python_modules/libraries/dagster-components/dagster_components_tests/code_locations/sling_location/components/ingest/component.yaml @@ -3,6 +3,8 @@ type: dagster_components.sling_replication_collection params: replications: - path: ./replication.yaml + asset_attributes: + key: "foo/{{ stream_definition.config.meta.dagster.asset_key }}" sling: connections: - name: DUCKDB diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/test_sling_integration_test.py b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/test_sling_integration_test.py index 758b850e1f3a4..b530454f5f042 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/test_sling_integration_test.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/test_sling_integration_test.py @@ -154,7 +154,7 @@ def test_load_from_path(sling_path: Path) -> None: assert len(components) == 1 assert get_asset_keys(components[0]) == { AssetKey("input_csv"), - AssetKey("input_duckdb"), + AssetKey(["foo", "input_duckdb"]), } assert_assets(components[0], 2) diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py b/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py index 59e2d446f8536..a7494133a899b 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py @@ -1,17 +1,17 @@ +from collections.abc import Sequence + import pytest from dagster import AssetKey, AssetSpec, AutomationCondition, Definitions from dagster_components.core.dsl_schema import ( - AssetAttributes, AssetAttributesModel, - MergeAttributes, - ReplaceAttributes, + AssetSpecTransform, TemplatedValueRenderer, ) from pydantic import BaseModel, TypeAdapter class M(BaseModel): - asset_attributes: AssetAttributes = [] + asset_attributes: Sequence[AssetSpecTransform] = [] defs = Definitions( @@ -24,7 +24,7 @@ class M(BaseModel): def test_replace_attributes() -> None: - op = ReplaceAttributes( + op = AssetSpecTransform( operation="replace", target="group:g2", attributes=AssetAttributesModel(tags={"newtag": "newval"}), @@ -38,7 +38,7 @@ def test_replace_attributes() -> None: def test_merge_attributes() -> None: - op = MergeAttributes( + op = AssetSpecTransform( operation="merge", target="group:g2", attributes=AssetAttributesModel(tags={"newtag": "newval"}), @@ -52,7 +52,7 @@ def test_merge_attributes() -> None: def test_render_attributes_asset_context() -> None: - op = MergeAttributes( + op = AssetSpecTransform( attributes=AssetAttributesModel(tags={"group_name_tag": "group__{{ asset.group_name }}"}) ) @@ -64,7 +64,7 @@ def test_render_attributes_asset_context() -> None: def test_render_attributes_custom_context() -> None: - op = ReplaceAttributes( + op = AssetSpecTransform( operation="replace", target="group:g2", attributes=AssetAttributesModel( @@ -102,11 +102,11 @@ def _custom_cron(s): # default to merge and a * target ( {"attributes": {"tags": {"a": "b"}}}, - MergeAttributes(target="*", attributes=AssetAttributesModel(tags={"a": "b"})), + AssetSpecTransform(target="*", attributes=AssetAttributesModel(tags={"a": "b"})), ), ( {"operation": "replace", "attributes": {"tags": {"a": "b"}}}, - ReplaceAttributes( + AssetSpecTransform( operation="replace", target="*", attributes=AssetAttributesModel(tags={"a": "b"}), @@ -115,14 +115,14 @@ def _custom_cron(s): # explicit target ( {"attributes": {"tags": {"a": "b"}}, "target": "group:g2"}, - MergeAttributes( + AssetSpecTransform( target="group:g2", attributes=AssetAttributesModel(tags={"a": "b"}), ), ), ( {"operation": "replace", "attributes": {"tags": {"a": "b"}}, "target": "group:g2"}, - ReplaceAttributes( + AssetSpecTransform( operation="replace", target="group:g2", attributes=AssetAttributesModel(tags={"a": "b"}), @@ -131,6 +131,6 @@ def _custom_cron(s): ], ) def test_load_attributes(python, expected) -> None: - loaded = TypeAdapter(AssetAttributes).validate_python([python]) + loaded = TypeAdapter(Sequence[AssetSpecTransform]).validate_python([python]) assert len(loaded) == 1 assert loaded[0] == expected