diff --git a/analysis_schema/Previous_Analysis_Schema/data_objects.py b/analysis_schema/Previous_Analysis_Schema/data_objects.py
index b61d29b..77ecba0 100644
--- a/analysis_schema/Previous_Analysis_Schema/data_objects.py
+++ b/analysis_schema/Previous_Analysis_Schema/data_objects.py
@@ -1,10 +1,10 @@
 import enum
 import typing

-from pydantic import BaseModel, Schema, create_model
+from pydantic import BaseModel, create_model

 from .fields import FieldName, FieldParameter
-from .quantities import Path, UnitfulArray, UnitfulCoordinate, UnitfulValue, Vector
+from .quantities import UnitfulCoordinate, UnitfulValue


 class FlatDefinitionsEnum(str, enum.Enum):
diff --git a/analysis_schema/Previous_Analysis_Schema/image_gallery.py b/analysis_schema/Previous_Analysis_Schema/image_gallery.py
index c9b6531..d36570e 100644
--- a/analysis_schema/Previous_Analysis_Schema/image_gallery.py
+++ b/analysis_schema/Previous_Analysis_Schema/image_gallery.py
@@ -1,7 +1,7 @@
 import enum
 from typing import Dict, List, Tuple, Union

-from pydantic import BaseModel, Field, Schema, create_model
+from pydantic import BaseModel, Field

 from .data_objects import DataSource
 from .fields import FieldName
diff --git a/analysis_schema/Previous_Analysis_Schema/products.py b/analysis_schema/Previous_Analysis_Schema/products.py
index 08def35..ff35561 100644
--- a/analysis_schema/Previous_Analysis_Schema/products.py
+++ b/analysis_schema/Previous_Analysis_Schema/products.py
@@ -1,10 +1,8 @@
-import typing
+from pydantic import BaseModel

-from pydantic import BaseModel, Schema
-
-from .data_objects import AllData, DataObject
+from .data_objects import DataObject
 from .dataset import Dataset
-from .operations import Operation, Sum
+from .operations import Operation


 class Projection(BaseModel):
diff --git a/analysis_schema/_data_store.py b/analysis_schema/_data_store.py
index 9a8706a..cc36228 100644
--- a/analysis_schema/_data_store.py
+++ b/analysis_schema/_data_store.py
@@ -1,3 +1,7 @@
+import contextlib
+from pathlib import PosixPath
+from typing import Optional
+
 import yt


@@ -9,42 +13,84 @@ def add_output(self, ytmodel_plotresult):
         self._output_list.append(ytmodel_plotresult)


-class DatasetFixture:
+class DatasetContext:
+    def __init__(self, fn, *args, in_memory_ds=None, **kwargs):
+        self.filename = fn
+        self.load_args = args
+        self.load_kwargs = kwargs
+
+        # _ds and _on_disk here (and below) are mainly for testing purposes so
+        # that in-memory datasets can be added to the data store.
+        self._ds = in_memory_ds
+        self._on_disk = in_memory_ds is None
+
+    @contextlib.contextmanager
+    def load(self):
+        if self._on_disk:
+            ds = yt.load(self.filename, *self.load_args, **self.load_kwargs)
+        else:
+            ds = self._ds
+
+        try:
+            yield ds
+        finally:
+            if self._on_disk:
+                # ds.close doesn't do anything for the majority of frontends...
+                # might as well call it though.
+                ds.close()
+            # do nothing if the ds is in-memory.
+
+    @contextlib.contextmanager
+    def load_sample(self):
+        ds = yt.load_sample(self.filename, *self.load_args, **self.load_kwargs)
+        try:
+            yield ds
+        finally:
+            # ds.close doesn't do anything for the majority of frontends...
+            # might as well call it though.
+            ds.close()
+
+
+class DataStore:
+    """
-    A class to hold all references and instantiated datasets.
-    Also has a method to instantiate the data if it isn't already.
-    There is a dictionary for dataset references and
-    instantiated datasets.
+    A class to hold all dataset references.
""" def __init__(self): - self.all_data = {} - self._instantiated_datasets = {} + self.available_datasets = {} - def add_to_alldata(self, fn: str, dataset_name: str): + def store(self, fn: str, dataset_name: Optional[str] = None, in_memory_ds=None): """ A function to track all dataset. Stores dataset name, or if no name is provided, adds a number as the name. """ - self.fn = fn - if dataset_name is not None: - self.dataset_name = dataset_name - else: - self.dataset_name = len(self.all_data.values()) - self.all_data[dataset_name] = fn + dataset_name = self.validate_name(fn, dataset_name) + + if dataset_name not in self.available_datasets: + self.available_datasets[dataset_name] = DatasetContext( + fn, in_memory_ds=in_memory_ds + ) - def _instantiate_data( + def validate_name(self, fn: str, dataset_name: str = None): + if dataset_name is None: + if isinstance(fn, PosixPath): + fn = str(fn) + dataset_name = fn + return dataset_name + + def retrieve( self, dataset_name: str, ): """ Instantiates a dataset and stores it in a separate dictionary. - Returns an instantiated (loaded into memory) dataset. + Returns a dataset context """ - ds = yt.load(self.all_data[dataset_name]) - self._instantiated_datasets[dataset_name] = ds - return ds - + if dataset_name in self.available_datasets: + return self.available_datasets[dataset_name] + else: + raise KeyError(f"{dataset_name} is not in the DataStore") -dataset_fixture = DatasetFixture() + def list_available(self): + return list(self.available_datasets.keys()) diff --git a/analysis_schema/_model_instantiation.py b/analysis_schema/_model_instantiation.py new file mode 100644 index 0000000..9538d05 --- /dev/null +++ b/analysis_schema/_model_instantiation.py @@ -0,0 +1,191 @@ +import abc +import inspect +import os + +import yt + +from . 
import base_model, data_classes + + +class YTRunner(abc.ABC): + @abc.abstractmethod + def process_pydantic(self, pydantic_instance, ds=None): + # take the pydantic model and return another object + pass + + def run(self, pydantic_instance, ds=None): + return self.process_pydantic(pydantic_instance, ds=ds) + + +class FieldNames(YTRunner): + def process_pydantic(self, pydantic_instance: data_classes.FieldNames, ds=None): + return (pydantic_instance.field_type, pydantic_instance.field) + + +class Dataset(YTRunner): + def process_pydantic(self, pydantic_instance: data_classes.Dataset, ds=None): + # always return the instantiated dataset + return ds + + +class DataSource3D(YTRunner): + def process_pydantic(self, pydantic_instance: data_classes.DataSource3D, ds=None): + for pyfield in pydantic_instance.__fields__.keys(): + pyval = getattr(pydantic_instance, pyfield, None) + if pyval is not None: + runner = YTGeneric() + return runner.run(pyval, ds=ds) + return None + + +class YTGeneric(YTRunner): + @staticmethod + def _determine_callable(pydantic_instance, ds=None): + if hasattr(pydantic_instance, "_yt_operation"): + yt_op = pydantic_instance._yt_operation # e.g., SlicePlot, sphere + else: + yt_op = type(pydantic_instance).__name__ + + if hasattr(yt, yt_op): # check top api + return getattr(yt, yt_op) + elif hasattr(ds, yt_op): # check ds-level api + return getattr(ds, yt_op) + + raise RuntimeError("could not determine yt callable") + + def _check_and_run(self, value, ds=None): + # potentially recursive as well + if _is_yt_schema_instance(value): + runner = yt_registry.get(value) + return runner.run(value, ds=ds) + elif isinstance(value, list): + if len(value) and _is_yt_schema_instance(value[0]): + if isinstance(value[0], data_classes.Dataset): + return self._check_and_run(value[0], ds=ds) + return [self._check_and_run(val, ds=ds) for val in value] + return value + else: + return value + + def process_pydantic(self, pydantic_instance, ds=None): + yt_func = self._determine_callable(pydantic_instance, ds=ds) + # the list that we'll use to eventually call our function + the_args = [] + + # now we get the arguments for the function: + # func_spec.args, which lists the named arguments and keyword arguments. + # ignoring vargs and kw-only args for now... + # see https://docs.python.org/3/library/inspect.html#inspect.getfullargspec + func_spec = inspect.getfullargspec(yt_func) + + # the argument position number at which we have default values (a little + # hacky, should be a better way to do this, and not sure how to scale it to + # include *args and **kwargs) + n_args = len(func_spec.args) # number of arguments + if func_spec.defaults is None: + # no default args, make sure we never get there... + named_kw_start_at = n_args + 1 + else: + # the position at which named keyword args start + named_kw_start_at = n_args - len(func_spec.defaults) + + # loop over the call signature arguments and pull out values from our pydantic + # class. this is recursive! will call _run() if a given argument value is also + # a ytBaseModel. + for arg_i, arg in enumerate(func_spec.args): + if arg in ["self", "cls"]: + continue + + # get the value for this argument. 
If it's not there, attempt to set default
+            # values for arguments needed for yt but not exposed in our pydantic class
+            try:
+                arg_value = getattr(pydantic_instance, arg)
+                if arg_value is None:
+                    default_index = arg_i - named_kw_start_at
+                    arg_value = func_spec.defaults[default_index]
+            except AttributeError:
+                if arg_i >= named_kw_start_at:
+                    # we are in the named keyword arguments, grab the default
+                    # the func_spec.defaults tuple 0 index is the first named
+                    # argument, so need to offset the arg_i counter
+                    default_index = arg_i - named_kw_start_at
+                    arg_value = func_spec.defaults[default_index]
+                else:
+                    raise AttributeError(f"could not find {arg}")
+
+            arg_value = self._check_and_run(arg_value, ds=ds)
+            the_args.append(arg_value)
+
+        # if this class has a list of known kwargs that we know will not be
+        # picked up by argspec, add them here. Not using inspect here because
+        # some of the yt visualization classes pass along kwargs, so we need
+        # to do this semi-manually for some classes and functions.
+        kwarg_dict = {}
+        if getattr(pydantic_instance, "_known_kwargs", None):
+            for kw in pydantic_instance._known_kwargs:
+                arg_value = getattr(pydantic_instance, kw, None)
+                arg_value = self._check_and_run(arg_value, ds=ds)
+                kwarg_dict[kw] = arg_value
+
+        return yt_func(*the_args, **kwarg_dict)
+
+
+class Visualizations(YTRunner):
+    def _sanitize_viz(self, viz_model, yt_viz):
+        if viz_model.output_type == "file":
+            # because we may be processing multiple datasets, we need to store
+            # output without dataset references, so save to a file here
+            if viz_model.output_dir and viz_model.output_file is None:
+                outdir = viz_model.output_dir
+                if outdir[-1] != os.sep:
+                    # needs to end in sep so save recognizes it as a directory
+                    outdir = outdir + os.sep
+                fi = yt_viz.save(outdir)
+            elif viz_model.output_file and viz_model.output_dir is None:
+                fi = yt_viz.save(viz_model.output_file)
+            elif viz_model.output_file and viz_model.output_dir:
+                fname = os.path.join(viz_model.output_dir, viz_model.output_file)
+                fi = yt_viz.save(fname)
+            else:
+                fi = yt_viz.save()
+            return fi[0]
+        elif viz_model.output_type == "html":
+            return yt_viz._repr_html_()
+
+    def process_pydantic(self, pydantic_instance: data_classes.Visualizations, ds=None):
+        generic_runner = YTGeneric()
+        viz_results = {}
+        for attr in pydantic_instance.__fields__.keys():
+            viz_model = getattr(pydantic_instance, attr)  # SlicePlot, etc.
+ if viz_model is not None: + result = generic_runner.run(viz_model, ds=ds) + nme = f"{ds.basename}_{attr}" + viz_results[nme] = self._sanitize_viz(viz_model, result) + return viz_results + + +class RunnerRegistry: + def __init__(self): + self._registry = {} + + def register(self, pydantic_class, runner): + if isinstance(runner, YTRunner) is False: + raise ValueError("the runner must be a YTRunner instance") + self._registry[pydantic_class] = runner + + def get(self, pydantic_class_instance): + pyd_type = type(pydantic_class_instance) + if pyd_type in self._registry: + return self._registry[pyd_type] + return YTGeneric() + + +def _is_yt_schema_instance(obj): + return isinstance(obj, base_model.ytBaseModel) + + +yt_registry = RunnerRegistry() +yt_registry.register(data_classes.FieldNames, FieldNames()) +yt_registry.register(data_classes.Visualizations, Visualizations()) +yt_registry.register(data_classes.Dataset, Dataset()) +yt_registry.register(data_classes.DataSource3D, DataSource3D()) diff --git a/analysis_schema/_testing.py b/analysis_schema/_testing.py new file mode 100644 index 0000000..15cbf2a --- /dev/null +++ b/analysis_schema/_testing.py @@ -0,0 +1,16 @@ +import os + +from yt.config import ytcfg + + +def yt_file_exists(req_file): + # returns True if yt can find the file, False otherwise (a simplification of + # yt.testing.requires_file without the nose dependency) + path = ytcfg.get("yt", "test_data_dir") + + if os.path.exists(req_file): + return True + else: + if os.path.exists(os.path.join(path, req_file)): + return True + return False diff --git a/analysis_schema/_workflows.py b/analysis_schema/_workflows.py new file mode 100644 index 0000000..8ecc2c8 --- /dev/null +++ b/analysis_schema/_workflows.py @@ -0,0 +1,124 @@ +import abc +import os +from pathlib import PosixPath +from typing import Union + +from ._data_store import DataStore +from ._model_instantiation import _is_yt_schema_instance, yt_registry +from .data_classes import Dataset +from .schema_model import ytModel + + +class BaseWorkflow: + def __init__(self, model, ds_name=None): + self.model = model + self.ds_name = ds_name + + @abc.abstractmethod + def run(self): + pass + + +class Workflow(BaseWorkflow): + def run(self, ds=None): + runner = yt_registry.get(self.model) + return runner.run(self.model, ds=ds) + + +class MainWorkflow: + def __init__(self, json_like: Union[str, PosixPath, dict]): + + self.model: ytModel = self._validate_json(json_like) + self.data_store = DataStore() + self.workflows_by_dataset = {} + self.workflows_with_no_dataset = [] + self.build_workflows() + + def _validate_json(self, json_like: Union[str, PosixPath, dict]): + + if isinstance(json_like, str): + # could be a file or a string + if os.path.isfile(json_like): + model = ytModel.parse_file(json_like) + else: + # might be a json string, try to parse it + model = ytModel.parse_raw(json_like) + elif isinstance(json_like, PosixPath): + model = ytModel.parse_file(json_like) + else: + model = ytModel.parse_obj(json_like) + return model + + def build_workflows(self): + + # first add the dataset definitions to the store + if self.model.Data is not None: + for datamodel in self.model.Data: + _check_for_ds(datamodel, set(), self.data_store) + + # assemble workflows, potentially adding more datasets to store + workflows_by_ds = {} + workflows_without_dataset = [] + if ( + self.model.Plot is not None + ): # could generalize this for other high-level schema objects + for plot in self.model.Plot: + dataset_names = _check_for_ds(plot, set(), 
self.data_store) + ds_list = list(dataset_names) # the datasets used by this workflow + if len(ds_list) > 0: + for dsname in ds_list: + # can specify multiple datasets to signal that a workflow should + # be run for each dataset, so duplicate the workflow for each + # dataset + if dsname not in workflows_by_ds: + workflows_by_ds[dsname] = [] + workflows_by_ds[dsname].append(Workflow(plot, dsname)) + else: + # nothing would end up here at present (ever?) + workflows_without_dataset.append(Workflow(plot)) + + self.workflows_by_dataset = workflows_by_ds + self.workflows_with_no_dataset = workflows_without_dataset + + def run_all(self): + output = [] + + # potential for parallel execution here... + for dsname, workflows in self.workflows_by_dataset.items(): + ds_context = self.data_store.retrieve(dsname) + with ds_context.load() as ds: + for workflow in workflows: + output.append(workflow.run(ds)) + + for workflow in self.workflows_with_no_dataset: + # should be empty... + output.append(workflow.run()) + return output + + +def _add_ds_to_store(pydantic_ds: Dataset, data_store): + ds_nm = pydantic_ds.DatasetName + fn = pydantic_ds.fn + data_store.store(fn, dataset_name=ds_nm) + name_to_add = data_store.validate_name(fn, ds_nm) + return name_to_add + + +def _check_for_ds(model, dataset_set: set, data_store): + # walk a pydantic model and add datasets to the data store + # returns a set of the datasets by short name + + if isinstance(model, Dataset): + name_to_add = _add_ds_to_store(model, data_store) + dataset_set.update((name_to_add,)) + elif _is_yt_schema_instance(model): + for attr in model.__fields__.keys(): + attval = getattr(model, attr) + if _is_yt_schema_instance(attval): + dataset_set = _check_for_ds(attval, dataset_set, data_store) + elif isinstance(attval, list) and len(attval): + if _is_yt_schema_instance(attval[0]): + for val in attval: + dataset_set = _check_for_ds(val, dataset_set, data_store) + + return dataset_set diff --git a/analysis_schema/base_model.py b/analysis_schema/base_model.py index db0efb3..5100520 100644 --- a/analysis_schema/base_model.py +++ b/analysis_schema/base_model.py @@ -1,10 +1,7 @@ -from inspect import getfullargspec from typing import List, Optional from pydantic import BaseModel -from ._data_store import dataset_fixture - def show_plots(schema, files): """ @@ -26,151 +23,6 @@ def show_plots(schema, files): class ytBaseModel(BaseModel): - """ - A class to connect attributes and their values to yt operations and their - keyword arguments. - - Args: - BaseModel ([type]): A pydantic basemodel in the form of a json schema - - Raises: - AttributeError: [description] - - Returns: - [list]: A list of yt classes to be run and then displayed - """ - _arg_mapping: dict = {} # mapping from internal yt name to schema name - _yt_operation: Optional[str] + _yt_operation: Optional[str] = None # the name of a yt operation _known_kwargs: Optional[List[str]] = None # a list of known keyword args - - def _run(self): - - # the list that we'll use to eventually call our function - the_args = [] - # this method actually executes the yt code - - # first make sure yt is imported and then get our function handle. This assumes - # that our class name exists in yt's top level api. 
- import yt - - funcname = getattr(self, "_yt_operation", type(self).__name__) - # if the function is not readily available in yt, move to the except block - # try: - func = getattr(yt, funcname) - - # now we get the arguments for the function: - # func_spec.args, which lists the named arguments and keyword arguments. - # ignoring vargs and kw-only args for now... - # see https://docs.python.org/3/library/inspect.html#inspect.getfullargspec - func_spec = getfullargspec(func) - - # the argument position number at which we have default values (a little - # hacky, should be a better way to do this, and not sure how to scale it to - # include *args and **kwargs) - n_args = len(func_spec.args) # number of arguments - if func_spec.defaults is None: - # no default args, make sure we never get there... - named_kw_start_at = n_args + 1 - else: - # the position at which named keyword args start - named_kw_start_at = n_args - len(func_spec.defaults) - - # loop over the call signature arguments and pull out values from our pydantic - # class. this is recursive! will call _run() if a given argument value is also - # a ytBaseModel. - for arg_i, arg in enumerate(func_spec.args): - # check if we've remapped the yt internal argument name for the schema - if arg in ["self", "cls"]: - continue - - # get the value for this argument. If it's not there, attempt to set default - # values for arguments needed for yt but not exposed in our pydantic class - try: - arg_value = getattr(self, arg) - if arg_value is None: - default_index = arg_i - named_kw_start_at - arg_value = func_spec.defaults[default_index] - except AttributeError: - if arg_i >= named_kw_start_at: - # we are in the named keyword arguments, grab the default - # the func_spec.defaults tuple 0 index is the first named - # argument, so need to offset the arg_i counter - default_index = arg_i - named_kw_start_at - arg_value = func_spec.defaults[default_index] - else: - raise AttributeError(f"could not file {arg}") - - if _check_run(arg_value): - arg_value = arg_value._run() - the_args.append(arg_value) - - # if this class has a list of known kwargs that we know will not be - # picked up by argspec, add them here. Not using inspect here because - # some of the yt visualization classes pass along kwargs, so we need - # to do this semi-manually for some classes and functions. 
- kwarg_dict = {} - if self._known_kwargs: - for kw in self._known_kwargs: - arg_value = getattr(self, kw, None) - if _check_run(arg_value): - arg_value = arg_value._run() - kwarg_dict[kw] = arg_value - - return func(*the_args, **kwarg_dict) - - -def _check_run(obj) -> bool: - # the following classes will have a ._run() attribute that needs to be called - if ( - isinstance(obj, ytBaseModel) - or isinstance(obj, ytParameter) - or isinstance(obj, ytDataObjectAbstract) - ): - return True - return False - - -class ytParameter(BaseModel): - _skip_these = ["comments"] - - def _run(self): - p = [ - getattr(self, key) - for key in self.schema()["properties"].keys() - if key not in self._skip_these - ] - if len(p) > 1: - raise ValueError("ytParameter instances can only have single values") - return p[0] - - -class ytDataObjectAbstract(ytBaseModel): - # abstract class for all the data selectors to inherit from - - def _run(self): - from yt.data_objects.data_containers import data_object_registry - - the_args = [] - funcname = getattr(self, "_yt_operation", type(self).__name__) - - # get the function from the data object registry - val = data_object_registry[funcname] - - # iterate through the arguments for the found data object - for arguments in val._con_args: - con_value = getattr(self, arguments) - # check that the argument is the correct instance - if isinstance(con_value, ytDataObjectAbstract): - # call the _run() function on the agrument - con_value = con_value._run() - the_args.append(con_value) - - if len(dataset_fixture._instantiated_datasets) > 0: - ds_keys = list(dataset_fixture._instantiated_datasets.keys()) - ds = dataset_fixture._instantiated_datasets[ds_keys[0]] - return val(*the_args, ds=ds) - else: - raise AttributeError( - "could not find a dataset: cannot build the data container" - ) diff --git a/analysis_schema/cli.py b/analysis_schema/cli.py index 1d7cf34..78551f8 100644 --- a/analysis_schema/cli.py +++ b/analysis_schema/cli.py @@ -33,7 +33,7 @@ def main(): def generate(model_type, schema_object, output): """generate a schema file""" - if hasattr(analysis_schema, model_type) is False: + if hasattr(analysis_schema.schema_model, model_type) is False: raise ValueError(f"{model_type} is not a valid analysis_schema model") # instantiate an empty model diff --git a/analysis_schema/data_classes.py b/analysis_schema/data_classes.py index 76817b2..79be9fd 100644 --- a/analysis_schema/data_classes.py +++ b/analysis_schema/data_classes.py @@ -1,10 +1,10 @@ +from enum import Enum from pathlib import Path from typing import List, Optional, Tuple, Union -from pydantic import BaseModel, Field +from pydantic import Field -from ._data_store import dataset_fixture -from .base_model import ytBaseModel, ytDataObjectAbstract, ytParameter +from analysis_schema.base_model import ytBaseModel class Dataset(ytBaseModel): @@ -20,22 +20,9 @@ class Dataset(ytBaseModel): description="A string containing the (path to the file and the) file name", ) comments: Optional[str] - # instantiate: bool = True - _yt_operation: str = "load" - - def _run(self): - if self.DatasetName is not None: - if self.DatasetName in [dataset_fixture._instantiated_datasets.keys()]: - return dataset_fixture._instantiated_datasets[self.DatasetName] - else: - dataset_fixture.add_to_alldata(self.fn, self.DatasetName) - ds = dataset_fixture._instantiate_data(self.DatasetName) - return ds - else: - raise AttributeError("Missing a dataset!") - - -class FieldNames(ytParameter): + + +class FieldNames(ytBaseModel): """ Specify a field name and 
field type from the dataset """ @@ -49,11 +36,8 @@ class FieldNames(ytParameter): _unit: Optional[str] comments: Optional[str] - def _run(self): - return (self.field_type, self.field) - -class Sphere(ytDataObjectAbstract): +class Sphere(ytBaseModel): """A sphere of points defined by a *center* and a *radius*.""" # found in the 'selection_data_containers.py' @@ -63,7 +47,7 @@ class Sphere(ytDataObjectAbstract): _yt_operation: str = "sphere" -class Region(ytDataObjectAbstract): +class Region(ytBaseModel): """A cartesian box data selection object""" center: List[float] @@ -72,7 +56,7 @@ class Region(ytDataObjectAbstract): _yt_operation: str = "region" -class Slice(ytDataObjectAbstract): +class Slice(ytBaseModel): """An axis-aligned 2-d slice data selection object""" axis: Union[int, str] @@ -88,13 +72,21 @@ class DataSource3D(ytBaseModel): sphere: Optional[Sphere] region: Optional[Region] - def _run(self): - for container in [self.sphere, self.region]: - if container: - return container._run() + +class ytVisType(str, Enum): + """Select visualization output type.""" + + file = "file" + html = "html" -class SlicePlot(ytBaseModel): +class ytVisualization(ytBaseModel): + output_type: ytVisType + output_file: Optional[str] = None + output_dir: Optional[str] = None + + +class SlicePlot(ytVisualization): """Axis-aligned slice plot.""" ds: Optional[List[Dataset]] = Field(alias="Dataset") @@ -109,41 +101,8 @@ class SlicePlot(ytBaseModel): "data_source", ] - def _run(self): - """ - This _run function checks if this plot has a value - for the `ds` - arguement (or attribute). - If it does not, then it looks for data in the - `DatasetFixture` class. - If there is more than one instantiated dataset, a plot - will be created for each dataset. - - return: a dataset, or a list of datasets - """ - figures = [] - if self.ds is None: - for instantiated_keys in list( - dataset_fixture._instantiated_datasets.keys() - ): - self.ds = dataset_fixture._instantiated_datasets[instantiated_keys] - # append output to a list to return - figures.append(super()._run()) - # put each 'self' into the output - # when calling `._run()` there is no plotting - # attribute, so it is not added to the output list - return figures - if self.ds is not None: - if isinstance(self.ds, list): - for data in self.ds: - self.ds = data - figures.append(super()._run()) - return figures - figures.append(super()._run()) - return figures - - -class ProjectionPlot(ytBaseModel): + +class ProjectionPlot(ytVisualization): """Axis-aligned projection plot.""" ds: Optional[List[Dataset]] = Field(alias="Dataset") @@ -173,38 +132,6 @@ class ProjectionPlot(ytBaseModel): Comments: Optional[str] _yt_operation: str = "ProjectionPlot" - def _run(self): - """ - This _run function checks if this plot has a value - for the `ds` arguement (or attribute). - If it does not, then it looks for data in the - `DatasetFixture` class. - If there is more than one instantiated dataset, - a plot will be created for each dataset. 
- - return: a dataset, or a list of datasets - """ - super_list = [] - if self.ds is None: - for instantiated_keys in list( - dataset_fixture._instantiated_datasets.keys() - ): - self.ds = dataset_fixture._instantiated_datasets[instantiated_keys] - # append output to a list to return - super_list.append(super()._run()) - # put each 'self' into the output - # when calling `._run()` there is no plotting - # attribute, so it is not added to the output list - return super_list - if self.ds is not None: - if isinstance(self.ds, list): - for data in self.ds: - self.ds = data - super_list.append(super()._run()) - return super_list - super_list.append(super()._run()) - return super_list - @property def axis(self): # yt <= 4.1.0 uses axis instead of normal, this aliasing allows the @@ -212,7 +139,7 @@ def axis(self): return self.normal -class PhasePlot(ytBaseModel): +class PhasePlot(ytVisualization): """A yt phase plot""" data_source: Optional[Dataset] = Field(alias="Dataset") @@ -232,30 +159,14 @@ class PhasePlot(ytBaseModel): Comments: Optional[str] _yt_operation: str = "PhasePlot" - def _run(self): - super_list = [] - if self.ds is None: - # self.ds = list(DatasetFixture._instantiated_datasets.values())[0] - for instantiated_keys in list( - dataset_fixture._instantiated_datasets.keys() - ): - self.ds = dataset_fixture._instantiated_datasets[instantiated_keys] - super_list.append(super()._run()) - # put each 'self' into the output - # when calling `._run()` there is no plotting - # attribute, so it is not added to the output list - return super_list - # return super()._run() - - -class Visualizations(BaseModel): + +class Visualizations(ytBaseModel): """ This class organizes the attributes below so users can select the plot by name, and see the correct arguments as suggestions """ - # use pydantic basemodel SlicePlot: Optional[SlicePlot] ProjectionPlot: Optional[ProjectionPlot] PhasePlot: Optional[PhasePlot] diff --git a/analysis_schema/pydantic_schema_example.json b/analysis_schema/pydantic_schema_example.json index 004d38b..85e69e1 100644 --- a/analysis_schema/pydantic_schema_example.json +++ b/analysis_schema/pydantic_schema_example.json @@ -5,11 +5,11 @@ "ProjectionPlot": { "Dataset": [ { - "FileName": "../../Data/IsolatedGalaxy/galaxy0030/galaxy0030", + "FileName": "IsolatedGalaxy/galaxy0030/galaxy0030", "DatasetName": "IG" }, { - "FileName": "../../Data/enzo_tiny_cosmology/DD0000/DD0000", + "FileName": "enzo_tiny_cosmology/DD0000/DD0000", "DatasetName": "Enzo" } ], @@ -21,7 +21,8 @@ "WeightFieldName": { "field": "temperature", "field_type": "gas" - } + }, + "output_type": "file" } } ] diff --git a/analysis_schema/run_analysis.py b/analysis_schema/run_analysis.py index de34ec0..6908629 100644 --- a/analysis_schema/run_analysis.py +++ b/analysis_schema/run_analysis.py @@ -1,11 +1,9 @@ import argparse -import json -from .base_model import show_plots -from .schema_model import ytModel +from analysis_schema._workflows import MainWorkflow -def load_and_run(json_file, files): +def load_and_run(json_file): """ A function to load the user JSON and load it into the analysis schema model, and the run that model to produce an output. 
@@ -13,20 +11,9 @@ def load_and_run(json_file, files): Args: json_file (json file): the JSON users edit """ - # open the file where the user is entering values - live_json = open(json_file) - # assign to a variable - live_schema = json.load(live_json) - live_json.close() - # remove schema line - live_schema.pop("$schema") - # create analysis schema model - if "Data" in list(live_schema.keys()): - analysis_model = ytModel(Data=live_schema["Data"], Plot=live_schema["Plot"]) - print(show_plots(analysis_model, files)) - else: - analysis_model = ytModel(Plot=live_schema["Plot"]) - print(show_plots(analysis_model, files)) + + wk = MainWorkflow(json_file) + return wk.run_all() if __name__ == "__main__": @@ -37,13 +24,7 @@ def load_and_run(json_file, files): # add the JSON file name agrument parser.add_argument("JSONFile", help="Call the JSON with the Schema to run") - parser.add_argument( - "ImageFormat", - nargs="*", - help="Enter 'Jupyter' to run .show() or a filename to run .save()", - ) - args = parser.parse_args() # run the analysis - load_and_run(args.JSONFile, args.ImageFormat) + load_and_run(args.JSONFile) diff --git a/analysis_schema/save_schema_file.py b/analysis_schema/save_schema_file.py index ef89e18..9aac54d 100644 --- a/analysis_schema/save_schema_file.py +++ b/analysis_schema/save_schema_file.py @@ -4,7 +4,7 @@ # which will be referenced by the user -def save_schema(): +def save_schema(fi: str = None): """ A function to create a schema file """ @@ -13,7 +13,10 @@ def save_schema(): Data=[{"DatasetName": "", "FileName": ""}], Plot=[{}] ) - with open("../analysis_schema/yt_analysis_schema.json", "w") as file: + if fi is None: + fi = "../analysis_schema/yt_analysis_schema.json" + + with open(fi, "w") as file: file.write(analysis_model_schema.schema_json(indent=2)) - print("Schema is has been saved!") + print("Schema has been saved!") diff --git a/analysis_schema/schema_model.py b/analysis_schema/schema_model.py index 34e0ef9..dd867ce 100644 --- a/analysis_schema/schema_model.py +++ b/analysis_schema/schema_model.py @@ -1,6 +1,5 @@ from typing import List, Optional -from ._data_store import Output from .base_model import ytBaseModel from .data_classes import Dataset, Visualizations @@ -23,35 +22,6 @@ class Config: title = "yt Schema Model for Descriptive Visualization and Analysis" underscore_attrs_are_private = True - def _run(self): - # for the top level model, we override this. - # Nested objects will still be recursive! 
- # output_list = [] - # because this inside the _run() function, - # it is wiped clean everytime it is called - attribute_data = self.Data - # creating an output instance to store viz - output = Output() - - if attribute_data is not None: - # the data does not get added to the output list, because we can't call - # .save() or .show() on it - for data in attribute_data: - data._run() - - attribute_plot = self.Plot - if attribute_plot is not None: - for data_class in attribute_plot: - for attribute in dir(data_class): - if attribute.endswith("Plot"): - plotting_attribute = getattr(data_class, attribute) - if plotting_attribute is not None: - output.add_output(plotting_attribute._run()) - output_flat = [ - viz for out in output._output_list for viz in out - ] - return output_flat - schema = ytModel schema_dict = schema.schema() diff --git a/analysis_schema/server.py b/analysis_schema/server.py index d29916c..5dde545 100644 --- a/analysis_schema/server.py +++ b/analysis_schema/server.py @@ -5,6 +5,8 @@ import pkg_resources +from analysis_schema._workflows import MainWorkflow + from .schema_model import schema # For static serving: @@ -144,7 +146,7 @@ def run_a_schema(json_payload_str): # parse and validate try: - valid_json = schema.parse_raw(json_payload_str) + wkflow = MainWorkflow(json_payload_str) except Exception as ex: return "".join( traceback.format_exception(etype=type(ex), value=ex, tb=ex.__traceback__) @@ -152,7 +154,7 @@ def run_a_schema(json_payload_str): # run it try: - results = valid_json._run() + results = wkflow.run_all() except Exception as ex: return "".join( traceback.format_exception(etype=type(ex), value=ex, tb=ex.__traceback__) diff --git a/analysis_schema/yt_analysis_schema.json b/analysis_schema/yt_analysis_schema.json index e559036..9782344 100644 --- a/analysis_schema/yt_analysis_schema.json +++ b/analysis_schema/yt_analysis_schema.json @@ -21,7 +21,7 @@ "definitions": { "Dataset": { "title": "Dataset", - "description": "The dataset to load. Filename (fn) must be a string.\n\nRequired fields: Filename, DatasetName", + "description": "The dataset to load. 
Filename (fn) must be a string.\n\nRequired fields: Filename", "type": "object", "properties": { "DatasetName": { @@ -44,6 +44,15 @@ "FileName" ] }, + "ytVisType": { + "title": "ytVisType", + "description": "Select visualization output type.", + "enum": [ + "file", + "html" + ], + "type": "string" + }, "FieldNames": { "title": "FieldNames", "description": "Specify a field name and field type from the dataset", @@ -69,7 +78,7 @@ }, "Sphere": { "title": "Sphere", - "description": "A sphere of points defined by a *center* and a *radius*.\n ", + "description": "A sphere of points defined by a *center* and a *radius*.", "type": "object", "properties": { "Center": { @@ -111,7 +120,7 @@ }, "Region": { "title": "Region", - "description": "A cartesian box data selection object\n ", + "description": "A cartesian box data selection object", "type": "object", "properties": { "center": { @@ -160,6 +169,17 @@ "description": "Axis-aligned slice plot.", "type": "object", "properties": { + "output_type": { + "$ref": "#/definitions/ytVisType" + }, + "output_file": { + "title": "Output File", + "type": "string" + }, + "output_dir": { + "title": "Output Dir", + "type": "string" + }, "Dataset": { "title": "Dataset", "type": "array", @@ -221,6 +241,7 @@ } }, "required": [ + "output_type", "FieldNames", "Axis" ] @@ -230,6 +251,17 @@ "description": "Axis-aligned projection plot.", "type": "object", "properties": { + "output_type": { + "$ref": "#/definitions/ytVisType" + }, + "output_file": { + "title": "Output File", + "type": "string" + }, + "output_dir": { + "title": "Output Dir", + "type": "string" + }, "Dataset": { "title": "Dataset", "type": "array", @@ -317,6 +349,7 @@ } }, "required": [ + "output_type", "FieldNames", "Axis" ] @@ -326,6 +359,17 @@ "description": "A yt phase plot", "type": "object", "properties": { + "output_type": { + "$ref": "#/definitions/ytVisType" + }, + "output_file": { + "title": "Output File", + "type": "string" + }, + "output_dir": { + "title": "Output Dir", + "type": "string" + }, "Dataset": { "$ref": "#/definitions/Dataset" }, @@ -396,6 +440,7 @@ } }, "required": [ + "output_type", "xField", "yField", "zField(s)" diff --git a/tests/test_analysis_schema_cli.py b/tests/test_analysis_schema_cli.py index 6409bba..8ba1e08 100644 --- a/tests/test_analysis_schema_cli.py +++ b/tests/test_analysis_schema_cli.py @@ -8,8 +8,7 @@ from click.testing import CliRunner -from analysis_schema import cli, ytModel -from analysis_schema.schema_model import _empty_model_registry, _model_types +from analysis_schema import cli, schema_model def test_command_line_interface(): @@ -41,7 +40,7 @@ def test_schema_generation(tmpdir): # check schema-string to screen schema_result = runner.invoke(cli.main, ["generate"]) assert schema_result.exit_code == 0 - mod = ytModel() + mod = schema_model.ytModel() s = mod.schema_json(indent=2) assert s in schema_result.output @@ -58,8 +57,8 @@ def test_schema_generation(tmpdir): assert all([s in schema_from_file.keys() for s in schema.keys()]) # check that we can generate a schema for a subset of the full schema - for mtype in _model_types: - cls, kwargs = _empty_model_registry[mtype] + for mtype in schema_model._model_types: + cls, kwargs = schema_model._empty_model_registry[mtype] mod = cls(**kwargs) base_args = ["generate", "--model_type", mtype] for obj in list(kwargs.keys()): @@ -88,11 +87,11 @@ def test_schema_availability(): # test the model type list generation schema_result = runner.invoke(cli.main, ["list-model-types"]) assert schema_result.exit_code == 0 - 
assert all([s in schema_result.output for s in _model_types]) + assert all([s in schema_result.output for s in schema_model._model_types]) # or each model type, check the list of schema_objects - for mtype in _model_types: - _, kwargs = _empty_model_registry["ytModel"] + for mtype in schema_model._model_types: + _, kwargs = schema_model._empty_model_registry["ytModel"] run_args = ["list-objects", "--model_type", mtype] schema_result = runner.invoke(cli.main, run_args) assert schema_result.exit_code == 0 diff --git a/tests/test_data_store.py b/tests/test_data_store.py new file mode 100644 index 0000000..988be10 --- /dev/null +++ b/tests/test_data_store.py @@ -0,0 +1,83 @@ +import pytest +from yt.testing import fake_random_ds + +from analysis_schema._data_store import DatasetContext, DataStore +from analysis_schema._testing import yt_file_exists + + +def test_data_storage(): + dstore = DataStore() + + dstore.store("test_file.hdf", dataset_name="test") + assert len(dstore.list_available()) == 1 + + # using the same dataset_name should not add another file + fi2 = "test_file_2.hdf" + dstore.store(fi2, dataset_name="test") + assert len(dstore.list_available()) == 1 + assert dstore.available_datasets["test"].filename == "test_file.hdf" + + # adding without a name will use the filename + dstore.store(fi2) + assert len(dstore.list_available()) == 2 + assert dstore.available_datasets[fi2].filename == fi2 + + # add an in-mem dataset + ds = fake_random_ds(3) + dstore.store("in_mem_ds", in_memory_ds=ds) + assert len(dstore.list_available()) == 3 + assert dstore.available_datasets["in_mem_ds"]._on_disk is False + assert dstore.available_datasets["in_mem_ds"]._ds == ds + + +def test_dataset_context_in_mem(): + ds = fake_random_ds(3) + dcont = DatasetContext("in_mem_ds", in_memory_ds=ds) + assert dcont._on_disk is False + assert dcont._ds == ds + with dcont.load() as ds_from_context: + assert ds == ds_from_context + + +def test_dataset_context_storage(): + fi = "IsolatedGalaxy/galaxy0030/galaxy0030" + dcont = DatasetContext(fi) + assert dcont._on_disk + + +def test_dataset_context_on_disk(): + # will only run if the dataset is available. 
+ fi = "IsolatedGalaxy/galaxy0030/galaxy0030" + if yt_file_exists(fi): + dcont = DatasetContext(fi) + assert dcont._on_disk + with dcont.load() as ds_from_context: + assert fi in ds_from_context.parameter_filename + else: + pytest.skip("Dataset file is unavailable.") + + +def test_loading_from_datastore(): + + files = [ + "IsolatedGalaxy/galaxy0030/galaxy0030", + "enzo_tiny_cosmology/DD0000/DD0000", + ] + + for fi in files: + if yt_file_exists(fi) is False: + pytest.skip(f"{fi} not found.") + + dstore = DataStore() + for fi in files: + dstore.store(fi) + + ds = fake_random_ds(3) + dstore.store("in_mem_ds", in_memory_ds=ds) + + assert len(dstore.list_available()) == 3 + + for ds_name in dstore.list_available(): + ds_con = dstore.retrieve(ds_name) + with ds_con.load() as ds_: + _ = ds_.domain_center diff --git a/tests/test_workflows.py b/tests/test_workflows.py new file mode 100644 index 0000000..89dc48a --- /dev/null +++ b/tests/test_workflows.py @@ -0,0 +1,84 @@ +import json +import os + +import pytest +from yt.testing import fake_amr_ds + +from analysis_schema._data_store import DataStore +from analysis_schema._testing import yt_file_exists +from analysis_schema._workflows import MainWorkflow + + +def test_workflow_instantiation(): + jfi = "analysis_schema/pydantic_schema_example.json" + _ = MainWorkflow(jfi) + + with open(jfi) as jstream: + jdict = json.loads(jstream.read()) + _ = MainWorkflow(jdict) + + +def test_full_execution(tmpdir): + jfi = "analysis_schema/pydantic_schema_example.json" + + # first check if the files for the workflow are available + init_wkflow = MainWorkflow(jfi) + files_used = [] + for dscontext in init_wkflow.data_store.available_datasets.values(): + files_used.append(dscontext.filename) + + for dsfi in files_used: + if yt_file_exists(dsfi) is False: + pytest.skip(f"{dsfi} not found.") + + # adjust the output files so they will write to a temporary directory + with open(jfi) as jstream: + jdict = json.loads(jstream.read()) + + newdict = jdict.copy() + for iplot, p in enumerate(jdict["Plot"]): + ptype = list(p.keys())[0] + p[ptype]["output_dir"] = str(tmpdir) + newdict["Plot"][iplot] = p + + # get a new workflow with the updated dictionary + wkflow = MainWorkflow(newdict) + + # actually run it and check that the figures exist + for output in wkflow.run_all(): + output_name = list(output.keys())[0] + output_fi = output[output_name] + assert os.path.isfile(str(output_fi)) + + +def test_execution_with_fake_ds(tmpdir): + jfi = "analysis_schema/pydantic_schema_example.json" + + # adjust the output files so they will write to a temporary directory + with open(jfi) as jstream: + jdict = json.loads(jstream.read()) + + newdict = jdict.copy() + + for iplot, p in enumerate(jdict["Plot"]): + ptype = list(p.keys())[0] + p[ptype]["output_dir"] = str(tmpdir) + newdict["Plot"][iplot] = p + + # get a new workflow with the updated dictionary + wkflow = MainWorkflow(newdict) + + # replace the data store datasets with in-memory datasets + new_store = DataStore() + flist = [("gas", "density"), ("gas", "temperature")] + ulist = ["g/cm**3", "K"] + for dsname, dscon in wkflow.data_store.available_datasets.items(): + ds_ = fake_amr_ds(fields=flist, units=ulist) + new_store.store(dscon.filename, dataset_name=dsname, in_memory_ds=ds_) + wkflow.data_store = new_store + + # actually run it and check that the figures exist + for output in wkflow.run_all(): + output_name = list(output.keys())[0] + output_fi = output[output_name] + assert os.path.isfile(str(output_fi)) diff --git 
a/tests/test_ytschema.py b/tests/test_ytschema.py index f622374..40af9b5 100644 --- a/tests/test_ytschema.py +++ b/tests/test_ytschema.py @@ -1,17 +1,7 @@ # isort: skip_file import json -import yt -from yt.testing import fake_amr_ds - -import analysis_schema -from analysis_schema._data_store import dataset_fixture -from analysis_schema.base_model import ( - _check_run, - ytBaseModel, - ytDataObjectAbstract, - ytParameter, -) +from analysis_schema.schema_model import ytModel ds_only = r""" { @@ -26,51 +16,41 @@ """ -viz_only_prj = r""" -{ - "$schema": "./yt_analysis_schema.json", - "Plot": [ - { - "ProjectionPlot": { - "Axis":"y", - "FieldNames": { - "field": "temperature", - "field_type": "gas" - }, - "DataSource": { - "region": { - "center": [0.25, 0.25, 0.25], - "left_edge": [0.0, 0.0, 0.0], - "right_edge": [0.5, 0.5, 0.5] - } - } - } - } - ] -} -""" - - -viz_only_slc = r""" +more_complete_example = r""" { - "$schema": "./yt_analysis_schema.json", - "Plot": [ - { - "SlicePlot": { - "Axis":"y", - "FieldNames": { - "field": "temperature", - "field_type": "gas" + "$schema": "../analysis_schema/yt_analysis_schema.json", + "Data": [ + { + "FileName": "IsolatedGalaxy/galaxy0030/galaxy0030", + "DatasetName": "IG_Testing" + } + ], + "Plot": [ + { + "ProjectionPlot": { + "Dataset": [ + { + "FileName": "great_filename", + "DatasetName": "nice" }, - "DataSource": { - "sphere": { - "Center": [0.25, 0.25, 0.25], - "Radius": 0.25 - } + { + "FileName": "and_another", + "DatasetName": "another" } - } + ], + "Axis":"y", + "FieldNames": { + "field": "density", + "field_type": "gas" + }, + "WeightFieldName": { + "field": "temperature", + "field_type": "gas" + }, + "output_type": "file" } - ] + } + ] } """ @@ -78,34 +58,11 @@ def test_validation(): # only testing the validation here, not instantiating yt objects - model = analysis_schema.ytModel.parse_raw(ds_only) + model = ytModel.parse_raw(ds_only) jdict = json.loads(ds_only) assert str(model.Data[0].fn) == jdict["Data"][0]["FileName"] - -def test_execution(): - - # we can inject an instantiated dataset here! the methods that require a - # ds will check the dataset store if ds is None and use this ds: - test_ds = fake_amr_ds(fields=[("gas", "temperature")], units=["K"]) - dataset_fixture._instantiated_datasets["_test_ds"] = test_ds - - # run the slice plot - model = analysis_schema.ytModel.parse_raw(viz_only_slc) - m = model._run() - print(m) - assert isinstance(m[0], yt.AxisAlignedSlicePlot) - - # run the projection plot - model = analysis_schema.ytModel.parse_raw(viz_only_prj) - m = model._run() - print(m) - assert isinstance(m[0], yt.ProjectionPlot) - - -def test_base_model(): - # some basic tests of base_model - for cls in [ytBaseModel, ytDataObjectAbstract, ytParameter]: - c = cls() - assert _check_run(c) - assert _check_run("someothertype") is False + model = ytModel.parse_raw(more_complete_example) + jdict = json.loads(ds_only) + assert str(model.Data[0].fn) == jdict["Data"][0]["FileName"] + assert str(model.Plot[0].ProjectionPlot.normal) == "y"
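
Usage sketch for the new entry point. `run_analysis.load_and_run` now just builds a `MainWorkflow` and calls `run_all()`; the sketch below is based on `pydantic_schema_example.json` and `tests/test_workflows.py` in this diff. It assumes the IsolatedGalaxy sample dataset is discoverable by yt (e.g. under yt's `test_data_dir`) and writes the plot to the current directory.

from analysis_schema._workflows import MainWorkflow

config = {
    "Plot": [
        {
            "ProjectionPlot": {
                "Dataset": [
                    {
                        "FileName": "IsolatedGalaxy/galaxy0030/galaxy0030",
                        "DatasetName": "IG",
                    }
                ],
                "Axis": "y",
                "FieldNames": {"field": "density", "field_type": "gas"},
                "WeightFieldName": {"field": "temperature", "field_type": "gas"},
                "output_type": "file",
                "output_dir": ".",
            }
        }
    ]
}

# MainWorkflow also accepts a path to a JSON file or a raw JSON string.
workflow = MainWorkflow(config)
for result in workflow.run_all():
    # each entry maps "<dataset basename>_<plot type>" to the saved image path
    print(result)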
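
Extension sketch for the runner registry. Schema classes without a registered runner fall back to `YTGeneric`, which introspects the yt call signature; a dedicated runner is only needed when that mapping breaks down. `MyCustomSelection` and `MyCustomRunner` below are hypothetical names used only to illustrate the `RunnerRegistry` contract, not part of this diff.

from analysis_schema import base_model
from analysis_schema._model_instantiation import YTRunner, yt_registry


class MyCustomSelection(base_model.ytBaseModel):
    # hypothetical schema class with a single validated field
    radius: float = 1.0


class MyCustomRunner(YTRunner):
    def process_pydantic(self, pydantic_instance, ds=None):
        # turn the validated pydantic object into a yt object; here we build
        # a sphere at the domain center of the instantiated dataset
        return ds.sphere("c", (pydantic_instance.radius, "kpc"))


# register() requires a YTRunner instance; registering overrides the
# YTGeneric fallback for this class
yt_registry.register(MyCustomSelection, MyCustomRunner())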