Skip to content

Commit

Permalink
[components] Remove translator from SlingReplicationCollectionComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 10, 2025
1 parent 8d1f82c commit 7be98c0
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from typing import AbstractSet, Annotated, Any, Literal, Optional, Union # noqa: UP035
from typing import Annotated, Any, Literal, Optional, Union

import dagster._check as check
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.asset_spec import AssetSpec, map_asset_specs
Expand All @@ -11,7 +12,7 @@
)
from dagster._core.definitions.definitions_class import Definitions
from dagster._record import replace
from pydantic import BaseModel, Field
from pydantic import BaseModel

from dagster_components.core.component_rendering import (
RenderedModel,
Expand Down Expand Up @@ -51,36 +52,45 @@ class AssetAttributesModel(RenderedModel):
] = None


class AssetSpecProcessor(ABC, BaseModel):
class AssetSpecTransform(ABC, BaseModel):
target: str = "*"
operation: Literal["merge", "replace"] = "merge"
attributes: AssetAttributesModel

class Config:
arbitrary_types_allowed = True

def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: ...

def apply_to_spec(
self,
spec: AssetSpec,
value_renderer: TemplatedValueRenderer,
target_keys: AbstractSet[AssetKey],
) -> AssetSpec:
if spec.key not in target_keys:
return spec

# add the original spec to the context and resolve values
return self._apply_to_spec(
spec, self.attributes.render_properties(value_renderer.with_context(asset=spec))
)
attributes = self.attributes.render_properties(value_renderer.with_context(asset=spec))

if self.operation == "merge":
mergeable_attributes = {"metadata", "tags"}
merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes}
replace_attributes = {
k: v for k, v in attributes.items() if k not in mergeable_attributes
}
return spec.merge_attributes(**merge_attributes).replace_attributes(
**replace_attributes
)
elif self.operation == "replace":
return spec.replace_attributes(**attributes)
else:
check.failed(f"Unsupported operation: {self.operation}")

def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> Definitions:
target_selection = AssetSelection.from_string(self.target, include_sources=True)
target_keys = target_selection.resolve(defs.get_asset_graph())

mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))]
mapped_assets = map_asset_specs(
lambda spec: self.apply_to_spec(spec, value_renderer, target_keys),
lambda spec: self.apply_to_spec(spec, value_renderer)
if spec.key in target_keys
else spec,
mappable,
)

Expand All @@ -89,27 +99,3 @@ def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> De
*[d for d in defs.assets or [] if not isinstance(d, (AssetsDefinition, AssetSpec))],
]
return replace(defs, assets=assets)


class MergeAttributes(AssetSpecProcessor):
# default operation is "merge"
operation: Literal["merge"] = "merge"

def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
mergeable_attributes = {"metadata", "tags"}
merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes}
replace_attributes = {k: v for k, v in attributes.items() if k not in mergeable_attributes}
return spec.merge_attributes(**merge_attributes).replace_attributes(**replace_attributes)


class ReplaceAttributes(AssetSpecProcessor):
# operation must be set explicitly
operation: Literal["replace"]

def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
return spec.replace_attributes(**attributes)


AssetAttributes = Sequence[
Annotated[Union[MergeAttributes, ReplaceAttributes], Field(union_mode="left_to_right")]
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from dagster_components import Component, ComponentLoadContext
from dagster_components.core.component import TemplatedValueRenderer, component_type
from dagster_components.core.dsl_schema import (
AssetAttributes,
AssetAttributesModel,
AssetSpecProcessor,
AssetSpecTransform,
OpSpecBaseModel,
)
from dagster_components.lib.dbt_project.generator import DbtProjectComponentGenerator
Expand All @@ -23,7 +22,7 @@ class DbtProjectParams(BaseModel):
dbt: DbtCliResource
op: Optional[OpSpecBaseModel] = None
translator: Optional[AssetAttributesModel] = None
asset_attributes: Optional[AssetAttributes] = None
transforms: Optional[Sequence[AssetSpecTransform]] = None


class DbtProjectComponentTranslator(DagsterDbtTranslator):
Expand Down Expand Up @@ -71,12 +70,12 @@ def __init__(
dbt_resource: DbtCliResource,
op_spec: Optional[OpSpecBaseModel],
dbt_translator: Optional[DagsterDbtTranslator],
asset_processors: Sequence[AssetSpecProcessor],
transforms: Sequence[AssetSpecTransform],
):
self.dbt_resource = dbt_resource
self.op_spec = op_spec
self.dbt_translator = dbt_translator
self.asset_processors = asset_processors
self.transforms = transforms

@classmethod
def get_generator(cls) -> "DbtProjectComponentGenerator":
Expand All @@ -97,7 +96,7 @@ def load(cls, context: ComponentLoadContext) -> Self:
params=loaded_params.translator,
value_renderer=context.templated_value_renderer,
),
asset_processors=loaded_params.asset_attributes or [],
transforms=loaded_params.transforms or [],
)

def build_defs(self, context: ComponentLoadContext) -> Definitions:
Expand All @@ -115,7 +114,7 @@ def _fn(context: AssetExecutionContext):
yield from self.execute(context=context, dbt=self.dbt_resource)

defs = Definitions(assets=[_fn])
for transform in self.asset_processors:
for transform in self.transforms:
defs = transform.apply(defs, context.templated_value_renderer)
return defs

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Iterator, Sequence
from pathlib import Path
from typing import Any, Optional, Union
from typing import Optional, Union

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.events import AssetMaterialization
from dagster._core.definitions.result import MaterializeResult
Expand All @@ -14,67 +12,26 @@
from typing_extensions import Self

from dagster_components import Component, ComponentLoadContext
from dagster_components.core.component import TemplatedValueRenderer, component_type
from dagster_components.core.component import component_type
from dagster_components.core.component_generator import ComponentGenerator
from dagster_components.core.dsl_schema import (
AssetAttributes,
AssetAttributesModel,
AssetSpecProcessor,
AssetSpecTransform,
OpSpecBaseModel,
)
from dagster_components.utils import get_wrapped_translator_class


class SlingReplicationParams(BaseModel):
path: str
op: Optional[OpSpecBaseModel] = None
translator: Optional[AssetAttributesModel] = None
asset_attributes: Optional[AssetAttributesModel] = None


class SlingReplicationCollectionParams(BaseModel):
sling: Optional[SlingResource] = None
replications: Sequence[SlingReplicationParams]
asset_attributes: Optional[AssetAttributes] = None


class SlingReplicationTranslator(DagsterSlingTranslator):
def __init__(
self,
*,
params: Optional[AssetAttributesModel],
value_renderer: TemplatedValueRenderer,
):
self.params = params or AssetAttributesModel()
self.value_renderer = value_renderer

def _get_rendered_attribute(
self, attribute: str, stream_definition: Mapping[str, Any], default_method
) -> Any:
renderer = self.value_renderer.with_context(stream_definition=stream_definition)
rendered_attribute = self.params.render_properties(renderer).get(attribute)
return (
rendered_attribute
if rendered_attribute is not None
else default_method(stream_definition)
)

def get_asset_key(self, stream_definition: Mapping[str, Any]) -> AssetKey:
return self._get_rendered_attribute("key", stream_definition, super().get_asset_key)

def get_group_name(self, stream_definition: Mapping[str, Any]) -> Optional[str]:
return self._get_rendered_attribute("group_name", stream_definition, super().get_group_name)

def get_tags(self, stream_definition: Mapping[str, Any]) -> Mapping[str, str]:
return self._get_rendered_attribute("tags", stream_definition, super().get_tags)

def get_metadata(self, stream_definition: Mapping[str, Any]) -> Mapping[str, Any]:
return self._get_rendered_attribute("metadata", stream_definition, super().get_metadata)

def get_auto_materialize_policy(
self, stream_definition: Mapping[str, Any]
) -> Optional[AutoMaterializePolicy]:
return self._get_rendered_attribute(
"auto_materialize_policy", stream_definition, super().get_auto_materialize_policy
)
transforms: Optional[Sequence[AssetSpecTransform]] = None


@component_type(name="sling_replication_collection")
Expand All @@ -84,12 +41,12 @@ def __init__(
dirpath: Path,
resource: SlingResource,
sling_replications: Sequence[SlingReplicationParams],
asset_attributes: Sequence[AssetSpecProcessor],
transforms: Sequence[AssetSpecTransform],
):
self.dirpath = dirpath
self.resource = resource
self.sling_replications = sling_replications
self.asset_attributes = asset_attributes
self.transforms = transforms

@classmethod
def get_generator(cls) -> ComponentGenerator:
Expand All @@ -110,18 +67,22 @@ def load(cls, context: ComponentLoadContext) -> Self:
dirpath=context.path,
resource=loaded_params.sling or SlingResource(),
sling_replications=loaded_params.replications,
asset_attributes=loaded_params.asset_attributes or [],
transforms=loaded_params.transforms or [],
)

def build_replication_asset(
self, context: ComponentLoadContext, replication: SlingReplicationParams
) -> AssetsDefinition:
translator_cls = get_wrapped_translator_class(DagsterSlingTranslator)

@sling_assets(
name=replication.op.name if replication.op else Path(replication.path).stem,
op_tags=replication.op.tags if replication.op else {},
replication_config=self.dirpath / replication.path,
dagster_sling_translator=SlingReplicationTranslator(
params=replication.translator,
dagster_sling_translator=translator_cls(
obj_name="stream_definition",
base_translator=DagsterSlingTranslator(),
asset_attributes=replication.asset_attributes or AssetAttributesModel(),
value_renderer=context.templated_value_renderer,
),
)
Expand All @@ -142,6 +103,6 @@ def build_defs(self, context: ComponentLoadContext) -> Definitions:
for replication in self.sling_replications
],
)
for transform in self.asset_attributes:
for transform in self.transforms:
defs = transform.apply(defs, context.templated_value_renderer)
return defs
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import importlib.util
import sys
from collections.abc import Iterator
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
from dagster._core.errors import DagsterError

from dagster_components.core.component_rendering import TemplatedValueRenderer
from dagster_components.core.dsl_schema import AssetAttributesModel

CLI_BUILTIN_COMPONENT_LIB_KEY = "builtin_component_lib"


Expand Down Expand Up @@ -40,3 +50,70 @@ def get_path_for_package(package_name: str) -> str:
if not submodule_search_locations:
raise DagsterError(f"Package does not have any locations for submodules: {package_name}")
return submodule_search_locations[0]


@dataclass
class ResolvingInfo:
obj_name: str
asset_attributes: AssetAttributesModel
value_renderer: TemplatedValueRenderer

def get_rendered_attribute(self, attribute: str, obj: Any, default_method) -> Any:
renderer = self.value_renderer.with_context(**{self.obj_name: obj})
rendered_attributes = self.asset_attributes.render_properties(renderer)
return (
rendered_attributes[attribute]
if attribute in rendered_attributes
else default_method(obj)
)

def get_asset_spec(self, base_spec: AssetSpec, context: Mapping[str, Any]) -> AssetSpec:
"""Returns an AssetSpec that combines the base spec with attributes resolved using the provided context.
Usage:
```python
class WrappedDagsterXTranslator(DagsterXTranslator):
def __init__(self, *, base_translator, resolving_info: ResolvingInfo):
self.base_translator = base_translator
self.resolving_info = resolving_info
def get_asset_spec(self, base_spec: AssetSpec, x_params: Any) -> AssetSpec:
return self.resolving_info.get_asset_spec(
base_spec, {"x_params": x_params}
)
```
"""
resolver = self.value_renderer.with_context(**context)
resolved_attributes = self.asset_attributes.render_properties(resolver)
return base_spec.replace_attributes(**resolved_attributes)


def get_wrapped_translator_class(translator_type: type):
"""Temporary hack to allow wrapping of many methods of a given translator class. Will be removed
once all translators implement `get_asset_spec`.
"""

class WrappedTranslator(translator_type):
def __init__(self, *, base_translator, resolving_info: ResolvingInfo):
self.base_translator = base_translator
self.resolving_info = resolving_info

def get_asset_key(self, obj: Any) -> AssetKey:
return self.resolving_info.get_rendered_attribute("key", obj, super().get_asset_key)

def get_group_name(self, obj: Any) -> Optional[str]:
return self.resolving_info.get_rendered_attribute(
"group_name", obj, super().get_group_name
)

def get_tags(self, obj: Any) -> Mapping[str, str]:
return self._get_rendered_attribute("tags", obj, super().get_tags)

def get_automation_condition(self, obj: Any) -> Optional[AutomationCondition]:
return self._get_rendered_attribute(
"automation_condition", obj, super().get_automation_condition
)

return WrappedTranslator
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ params:
dbt:
project_dir: jaffle_shop

asset_attributes:
asset_transforms:
- attributes:
tags:
foo: bar
Expand Down
Loading

0 comments on commit 7be98c0

Please sign in to comment.