Skip to content

Commit

Permalink
[components] Convert params schema from classvar to classmethod
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 8, 2025
1 parent 0ff03a6 commit 9dbf553
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def generate_component_command(
raise Exception(
f"Component type {component_type} does not have a generator. Reason: {generator.message}."
)
generate_params = TypeAdapter(generator.generator_params).validate_json(json_params)
generate_params = TypeAdapter(generator.get_params_schema_type()).validate_json(json_params)
else:
generate_params = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.errors import DagsterError
from dagster._utils import pushd, snakecase
from pydantic import TypeAdapter
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

from dagster_components.core.component_generator import (
Expand All @@ -44,7 +44,10 @@ class ComponentDeclNode: ...

class Component(ABC):
name: ClassVar[Optional[str]] = None
params_schema: ClassVar = None

@classmethod
def get_params_schema_type(cls) -> Optional[Type[BaseModel]]:
return None

@classmethod
def get_generator(cls) -> Union[ComponentGenerator, ComponentGeneratorUnavailableReason]:
Expand Down Expand Up @@ -78,13 +81,17 @@ def get_metadata(cls) -> "ComponentTypeInternalMetadata":
if isinstance(generator, ComponentGeneratorUnavailableReason):
raise DagsterError(f"Component {cls.__name__} is not scaffoldable: {generator.message}")

component_params = cls.get_params_schema_type()
generator_params = generator.get_params_schema_type()
return {
"summary": clean_docstring.split("\n\n")[0] if clean_docstring else None,
"description": clean_docstring if clean_docstring else None,
"generate_params_schema": generator.generator_params.schema()
if generator.generator_params
else None,
"component_params_schema": cls.params_schema.schema() if cls.params_schema else None,
"generate_params_schema": None
if generator_params is None
else generator_params.model_json_schema(),
"component_params_schema": None
if component_params is None
else component_params.model_json_schema(),
}

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar
from typing import Any, Optional, Type

from dagster._record import record
from pydantic import BaseModel


@record
Expand All @@ -13,7 +14,9 @@ class ComponentGenerateRequest:


class ComponentGenerator:
generator_params: ClassVar = None
@classmethod
def get_params_schema_type(cls) -> Optional[Type[BaseModel]]:
return None

@abstractmethod
def generate_files(self, request: ComponentGenerateRequest, params: Any) -> None: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def get_automation_condition(self, dbt_resource_props):

@component_type(name="dbt_project")
class DbtProjectComponent(Component):
params_schema = DbtProjectParams

def __init__(
self,
dbt_resource: DbtCliResource,
Expand All @@ -83,9 +81,13 @@ def __init__(
def get_generator(cls) -> "DbtProjectComponentGenerator":
return DbtProjectComponentGenerator()

@classmethod
def get_params_schema_type(cls):
return DbtProjectParams

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
loaded_params = context.load_params(cls.params_schema)
loaded_params = context.load_params(cls.get_params_schema_type())

return cls(
dbt_resource=loaded_params.dbt,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Optional
from typing import Optional, Type

import dagster._check as check
from dbt.cli.main import dbtRunner
Expand All @@ -16,7 +16,9 @@ class DbtGenerateParams(BaseModel):


class DbtProjectComponentGenerator(ComponentGenerator):
generator_params = DbtGenerateParams
@classmethod
def get_params_schema_type(cls) -> Optional[Type[BaseModel]]:
return DbtGenerateParams

def generate_files(self, request: ComponentGenerateRequest, params: DbtGenerateParams) -> None:
cwd = os.getcwd()
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ class DefinitionsComponent(Component):
def __init__(self, definitions_path: Path):
self.definitions_path = definitions_path

params_schema = DefinitionsParamSchema

@classmethod
def get_generator(cls) -> DefinitionsComponentGenerator:
return DefinitionsComponentGenerator()

@classmethod
def get_params_schema_type(cls):
return DefinitionsParamSchema

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
# all paths should be resolved relative to the directory we're in
loaded_params = context.load_params(cls.params_schema)
loaded_params = context.load_params(cls.get_params_schema_type())

return cls(definitions_path=Path(loaded_params.definitions_path or "definitions.py"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class DefinitionsGenerateParams(BaseModel):


class DefinitionsComponentGenerator(ComponentGenerator):
generator_params = DefinitionsGenerateParams
@classmethod
def get_params_schema_type(cls):
return DefinitionsGenerateParams

def generate_files(
self, request: ComponentGenerateRequest, params: DefinitionsGenerateParams
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class PipesSubprocessScriptCollectionParams(BaseModel):
class PipesSubprocessScriptCollection(Component):
"""Assets that wrap Python scripts executed with Dagster's PipesSubprocessClient."""

params_schema = PipesSubprocessScriptCollectionParams

def __init__(self, dirpath: Path, path_specs: Mapping[Path, Sequence[AssetSpec]]):
self.dirpath = dirpath
# mapping from the script name (e.g. /path/to/script_abc.py -> script_abc)
Expand All @@ -42,9 +40,13 @@ def introspect_from_path(path: Path) -> "PipesSubprocessScriptCollection":
path_specs = {path: [AssetSpec(path.stem)] for path in list(path.rglob("*.py"))}
return PipesSubprocessScriptCollection(dirpath=path, path_specs=path_specs)

@classmethod
def get_params_schema_type(cls):
return PipesSubprocessScriptCollectionParams

@classmethod
def load(cls, context: ComponentLoadContext) -> "PipesSubprocessScriptCollection":
loaded_params = context.load_params(cls.params_schema)
loaded_params = context.load_params(cls.get_params_schema_type())

path_specs = {}
for script in loaded_params.scripts:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def get_auto_materialize_policy(

@component_type(name="sling_replication_collection")
class SlingReplicationCollectionComponent(Component):
params_schema = SlingReplicationCollectionParams

def __init__(
self,
dirpath: Path,
Expand All @@ -100,9 +98,13 @@ def get_generator(cls) -> ComponentGenerator:

return SlingReplicationComponentGenerator()

@classmethod
def get_params_schema_type(cls):
return SlingReplicationCollectionParams

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
loaded_params = context.load_params(cls.params_schema)
loaded_params = context.load_params(cls.get_params_schema_type())
return cls(
dirpath=context.path,
resource=loaded_params.sling or SlingResource(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class SimpleAssetParams(BaseModel):
class SimpleAsset(Component):
"""A simple asset that returns a constant string value."""

params_schema = SimpleAssetParams
@classmethod
def get_params_schema_type(cls):
return SimpleAssetParams

@classmethod
def get_generator(cls) -> ComponentGenerator:
Expand All @@ -38,7 +40,7 @@ def from_decl_node(
cls, context: "ComponentLoadContext", decl_node: "ComponentDeclNode"
) -> Self:
assert isinstance(decl_node, YamlComponentDecl)
loaded_params = TypeAdapter(cls.params_schema).validate_python(
loaded_params = TypeAdapter(cls.get_params_schema_type()).validate_python(
decl_node.component_file_model.params
)
return cls(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class SimplePipesScriptAssetParams(BaseModel):


class SimplePipesScriptAssetGenerator(ComponentGenerator):
generator_params = SimplePipesScriptAssetParams
@classmethod
def get_params_schema_type(cls):
return SimplePipesScriptAssetParams

def generate_files(
self, request: ComponentGenerateRequest, params: SimplePipesScriptAssetParams
Expand Down Expand Up @@ -54,18 +56,20 @@ class SimplePipesScriptAsset(Component):
Because it is a pipes asset, no value is returned.
"""

params_schema = SimplePipesScriptAssetParams

@classmethod
def get_generator(cls) -> ComponentGenerator:
return SimplePipesScriptAssetGenerator()

@classmethod
def get_params_schema_type(cls):
return SimplePipesScriptAssetParams

@classmethod
def from_decl_node(
cls, context: "ComponentLoadContext", decl_node: "ComponentDeclNode"
) -> Self:
assert isinstance(decl_node, YamlComponentDecl)
loaded_params = TypeAdapter(cls.params_schema).validate_python(
loaded_params = TypeAdapter(cls.get_params_schema_type()).validate_python(
decl_node.component_file_model.params
)
return cls(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class CustomScopeParams(BaseModel):

@component_type(name="custom_scope_component")
class HasCustomScope(Component):
params_schema = CustomScopeParams

@classmethod
def get_rendering_scope(cls) -> Mapping[str, Any]:
return {
Expand All @@ -33,9 +31,13 @@ def get_rendering_scope(cls) -> Mapping[str, Any]:
def __init__(self, attributes: Mapping[str, Any]):
self.attributes = attributes

@classmethod
def get_params_schema_type(cls):
return CustomScopeParams

@classmethod
def load(cls, context: ComponentLoadContext):
loaded_params = context.load_params(cls.params_schema)
loaded_params = context.load_params(cls.get_params_schema_type())
return cls(attributes=loaded_params.attributes)

def build_defs(self, context: ComponentLoadContext):
Expand Down

0 comments on commit 9dbf553

Please sign in to comment.