
Merge pull request #21 from chrishavlin/schema_runner
Schema execution (re)implementation
samwalkow authored Aug 30, 2022
2 parents 1a1c695 + d810507 commit f264deb
Showing 20 changed files with 712 additions and 449 deletions.
4 changes: 2 additions & 2 deletions analysis_schema/Previous_Analysis_Schema/data_objects.py
@@ -1,10 +1,10 @@
import enum
import typing

-from pydantic import BaseModel, Schema, create_model
+from pydantic import BaseModel, create_model

from .fields import FieldName, FieldParameter
-from .quantities import Path, UnitfulArray, UnitfulCoordinate, UnitfulValue, Vector
+from .quantities import UnitfulCoordinate, UnitfulValue


class FlatDefinitionsEnum(str, enum.Enum):
2 changes: 1 addition & 1 deletion analysis_schema/Previous_Analysis_Schema/image_gallery.py
@@ -1,7 +1,7 @@
import enum
from typing import Dict, List, Tuple, Union

-from pydantic import BaseModel, Field, Schema, create_model
+from pydantic import BaseModel, Field

from .data_objects import DataSource
from .fields import FieldName
8 changes: 3 additions & 5 deletions analysis_schema/Previous_Analysis_Schema/products.py
@@ -1,10 +1,8 @@
-import typing
+from pydantic import BaseModel

-from pydantic import BaseModel, Schema
-
-from .data_objects import AllData, DataObject
+from .data_objects import DataObject
from .dataset import Dataset
-from .operations import Operation, Sum
+from .operations import Operation


class Projection(BaseModel):
88 changes: 67 additions & 21 deletions analysis_schema/_data_store.py
@@ -1,3 +1,7 @@
+import contextlib
+from pathlib import PosixPath
+from typing import Optional
+
import yt


@@ -9,42 +13,84 @@ def add_output(self, ytmodel_plotresult):
        self._output_list.append(ytmodel_plotresult)


-class DatasetFixture:
+class DatasetContext:
+    def __init__(self, fn, *args, in_memory_ds=None, **kwargs):
+        self.filename = fn
+        self.load_args = args
+        self.load_kwargs = kwargs
+
+        # _ds and _on_disk here (and below) are mainly for testing purposes so
+        # that in-memory datasets can be added to the data store.
+        self._ds = in_memory_ds
+        self._on_disk = in_memory_ds is None
+
+    @contextlib.contextmanager
+    def load(self):
+        if self._on_disk:
+            ds = yt.load(self.filename, *self.load_args, **self.load_kwargs)
+        else:
+            ds = self._ds
+
+        try:
+            yield ds
+        finally:
+            if self._on_disk:
+                # ds.close doesn't do anything for the majority of frontends...
+                # might as well call it though.
+                ds.close()
+            # do nothing if the ds is in-memory.
+
+    @contextlib.contextmanager
+    def load_sample(self):
+        ds = yt.load_sample(self.filename, *self.load_args, **self.load_kwargs)
+        try:
+            yield ds
+        finally:
+            # ds.close doesn't do anything for the majority of frontends...
+            # might as well call it though.
+            ds.close()


class DataStore:
    """
-    A class to hold all references and instantiated datasets.
-    Also has a method to instantiate the data if it isn't already.
-    There is a dictionary for dataset references and
-    instantiated datasets.
+    A class to hold all dataset references.
    """

    def __init__(self):
-        self.all_data = {}
-        self._instantiated_datasets = {}
+        self.available_datasets = {}

-    def add_to_alldata(self, fn: str, dataset_name: str):
+    def store(self, fn: str, dataset_name: Optional[str] = None, in_memory_ds=None):
        """
        A function to track all datasets.
        Stores the dataset under dataset_name or, if no name is
        provided, under its filename.
        """
-        self.fn = fn
-        if dataset_name is not None:
-            self.dataset_name = dataset_name
-        else:
-            self.dataset_name = len(self.all_data.values())
-        self.all_data[dataset_name] = fn
+        dataset_name = self.validate_name(fn, dataset_name)
+
+        if dataset_name not in self.available_datasets:
+            self.available_datasets[dataset_name] = DatasetContext(
+                fn, in_memory_ds=in_memory_ds
+            )

-    def _instantiate_data(
+    def validate_name(self, fn: str, dataset_name: str = None):
+        if dataset_name is None:
+            if isinstance(fn, PosixPath):
+                fn = str(fn)
+            dataset_name = fn
+        return dataset_name
+
+    def retrieve(
        self,
        dataset_name: str,
    ):
        """
-        Instantiates a dataset and stores it in a separate dictionary.
-        Returns an instantiated (loaded into memory) dataset.
+        Returns a dataset context.
        """
-        ds = yt.load(self.all_data[dataset_name])
-        self._instantiated_datasets[dataset_name] = ds
-        return ds
+        if dataset_name in self.available_datasets:
+            return self.available_datasets[dataset_name]
+        else:
+            raise KeyError(f"{dataset_name} is not in the DataStore")

-
-dataset_fixture = DatasetFixture()
+    def list_available(self):
+        return list(self.available_datasets.keys())
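
For orientation, a minimal usage sketch of the reworked store (not part of the diff; the dataset path is purely illustrative and assumes a file yt can locate):

# Sketch of the new DataStore / DatasetContext API from the diff above.
# The sample path is an assumption, used here only for illustration.
from analysis_schema._data_store import DataStore

store = DataStore()
store.store("IsolatedGalaxy/galaxy0030/galaxy0030")  # name defaults to the filename
print(store.list_available())

context = store.retrieve("IsolatedGalaxy/galaxy0030/galaxy0030")
with context.load() as ds:  # yt.load happens here; on-disk data is closed on exit
    print(ds.domain_width)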
191 changes: 191 additions & 0 deletions analysis_schema/_model_instantiation.py
@@ -0,0 +1,191 @@
import abc
import inspect
import os

import yt

from . import base_model, data_classes


class YTRunner(abc.ABC):
    @abc.abstractmethod
    def process_pydantic(self, pydantic_instance, ds=None):
        # take the pydantic model and return another object
        pass

    def run(self, pydantic_instance, ds=None):
        return self.process_pydantic(pydantic_instance, ds=ds)


class FieldNames(YTRunner):
    def process_pydantic(self, pydantic_instance: data_classes.FieldNames, ds=None):
        return (pydantic_instance.field_type, pydantic_instance.field)


class Dataset(YTRunner):
    def process_pydantic(self, pydantic_instance: data_classes.Dataset, ds=None):
        # always return the instantiated dataset
        return ds


class DataSource3D(YTRunner):
    def process_pydantic(self, pydantic_instance: data_classes.DataSource3D, ds=None):
        for pyfield in pydantic_instance.__fields__.keys():
            pyval = getattr(pydantic_instance, pyfield, None)
            if pyval is not None:
                runner = YTGeneric()
                return runner.run(pyval, ds=ds)
        return None


class YTGeneric(YTRunner):
    @staticmethod
    def _determine_callable(pydantic_instance, ds=None):
        if hasattr(pydantic_instance, "_yt_operation"):
            yt_op = pydantic_instance._yt_operation  # e.g., SlicePlot, sphere
        else:
            yt_op = type(pydantic_instance).__name__

        if hasattr(yt, yt_op):  # check top-level api
            return getattr(yt, yt_op)
        elif hasattr(ds, yt_op):  # check ds-level api
            return getattr(ds, yt_op)

        raise RuntimeError("could not determine yt callable")

    def _check_and_run(self, value, ds=None):
        # potentially recursive as well
        if _is_yt_schema_instance(value):
            runner = yt_registry.get(value)
            return runner.run(value, ds=ds)
        elif isinstance(value, list):
            if len(value) and _is_yt_schema_instance(value[0]):
                if isinstance(value[0], data_classes.Dataset):
                    return self._check_and_run(value[0], ds=ds)
                return [self._check_and_run(val, ds=ds) for val in value]
            return value
        else:
            return value

    def process_pydantic(self, pydantic_instance, ds=None):
        yt_func = self._determine_callable(pydantic_instance, ds=ds)
        # the list that we'll use to eventually call our function
        the_args = []

        # now we get the arguments for the function:
        # func_spec.args, which lists the named arguments and keyword arguments.
        # ignoring varargs and kw-only args for now...
        # see https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
        func_spec = inspect.getfullargspec(yt_func)

        # the argument position number at which we have default values (a little
        # hacky, should be a better way to do this, and not sure how to scale it to
        # include *args and **kwargs)
        n_args = len(func_spec.args)  # number of arguments
        if func_spec.defaults is None:
            # no default args, make sure we never get there...
            named_kw_start_at = n_args + 1
        else:
            # the position at which named keyword args start
            named_kw_start_at = n_args - len(func_spec.defaults)

        # loop over the call signature arguments and pull out values from our
        # pydantic class. this is recursive! will call run() if a given argument
        # value is also a ytBaseModel.
        for arg_i, arg in enumerate(func_spec.args):
            if arg in ["self", "cls"]:
                continue

            # get the value for this argument. If it's not there, attempt to set
            # default values for arguments needed for yt but not exposed in our
            # pydantic class
            try:
                arg_value = getattr(pydantic_instance, arg)
                if arg_value is None:
                    default_index = arg_i - named_kw_start_at
                    arg_value = func_spec.defaults[default_index]
            except AttributeError:
                if arg_i >= named_kw_start_at:
                    # we are in the named keyword arguments, grab the default:
                    # the func_spec.defaults tuple 0 index is the first named
                    # argument, so need to offset the arg_i counter
                    default_index = arg_i - named_kw_start_at
                    arg_value = func_spec.defaults[default_index]
                else:
                    raise AttributeError(f"could not find {arg}")

            arg_value = self._check_and_run(arg_value, ds=ds)
            the_args.append(arg_value)

        # if this class has a list of known kwargs that we know will not be
        # picked up by argspec, add them here. Not using inspect here because
        # some of the yt visualization classes pass along kwargs, so we need
        # to do this semi-manually for some classes and functions.
        kwarg_dict = {}
        if getattr(pydantic_instance, "_known_kwargs", None):
            for kw in pydantic_instance._known_kwargs:
                arg_value = getattr(pydantic_instance, kw, None)
                arg_value = self._check_and_run(arg_value, ds=ds)
                kwarg_dict[kw] = arg_value

        return yt_func(*the_args, **kwarg_dict)


class Visualizations(YTRunner):
    def _sanitize_viz(self, viz_model, yt_viz):
        if viz_model.output_type == "file":
            # because we may be processing multiple datasets, need to store
            # objects without dataset references -- save
            if viz_model.output_dir and viz_model.output_file is None:
                outdir = viz_model.output_dir
                if outdir[-1] != os.sep:
                    # needs to end in sep so save recognizes it as a directory
                    outdir = outdir + os.sep
                fi = yt_viz.save(outdir)
            elif viz_model.output_file and viz_model.output_dir is None:
                fi = yt_viz.save(viz_model.output_file)
            elif viz_model.output_file and viz_model.output_dir:
                fname = os.path.join(viz_model.output_dir, viz_model.output_file)
                fi = yt_viz.save(fname)
            else:
                fi = yt_viz.save()
            return fi[0]
        elif viz_model.output_type == "html":
            return yt_viz._repr_html_()

    def process_pydantic(self, pydantic_instance: data_classes.Visualizations, ds=None):
        generic_runner = YTGeneric()
        viz_results = {}
        for attr in pydantic_instance.__fields__.keys():
            viz_model = getattr(pydantic_instance, attr)  # SlicePlot, etc.
            if viz_model is not None:
                result = generic_runner.run(viz_model, ds=ds)
                nme = f"{ds.basename}_{attr}"
                viz_results[nme] = self._sanitize_viz(viz_model, result)
        return viz_results


class RunnerRegistry:
    def __init__(self):
        self._registry = {}

    def register(self, pydantic_class, runner):
        if isinstance(runner, YTRunner) is False:
            raise ValueError("the runner must be a YTRunner instance")
        self._registry[pydantic_class] = runner

    def get(self, pydantic_class_instance):
        pyd_type = type(pydantic_class_instance)
        if pyd_type in self._registry:
            return self._registry[pyd_type]
        return YTGeneric()


def _is_yt_schema_instance(obj):
    return isinstance(obj, base_model.ytBaseModel)


yt_registry = RunnerRegistry()
yt_registry.register(data_classes.FieldNames, FieldNames())
yt_registry.register(data_classes.Visualizations, Visualizations())
yt_registry.register(data_classes.Dataset, Dataset())
yt_registry.register(data_classes.DataSource3D, DataSource3D())
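
To show how the registry dispatch works in practice, a sketch of registering a custom runner (MyFieldNamesRunner and its behavior are hypothetical; only YTRunner, yt_registry, and data_classes come from the diff above):

# Any pydantic class without an explicit runner falls back to YTGeneric,
# which resolves a yt callable by name; registering a YTRunner overrides
# that. This custom runner is a hypothetical example.
from analysis_schema import data_classes
from analysis_schema._model_instantiation import YTRunner, yt_registry


class MyFieldNamesRunner(YTRunner):
    def process_pydantic(self, pydantic_instance, ds=None):
        # swap the tuple order relative to the default FieldNames runner
        return (pydantic_instance.field, pydantic_instance.field_type)


yt_registry.register(data_classes.FieldNames, MyFieldNamesRunner())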
16 changes: 16 additions & 0 deletions analysis_schema/_testing.py
@@ -0,0 +1,16 @@
import os

from yt.config import ytcfg


def yt_file_exists(req_file):
    # returns True if yt can find the file, False otherwise (a simplification of
    # yt.testing.requires_file without the nose dependency)
    path = ytcfg.get("yt", "test_data_dir")

    if os.path.exists(req_file):
        return True
    else:
        if os.path.exists(os.path.join(path, req_file)):
            return True
    return False
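
As a usage note, this helper lends itself to test skips; a sketch under assumptions (pytest usage and the dataset path are illustrative, not part of this commit):

# Hypothetical test-side use of yt_file_exists; assumes pytest is installed
# and uses an illustrative dataset path.
import pytest

from analysis_schema._testing import yt_file_exists


@pytest.mark.skipif(
    not yt_file_exists("IsolatedGalaxy/galaxy0030/galaxy0030"),
    reason="test dataset not available",
)
def test_with_data():
    ...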
