There are two ways to add neural network architectures to the project: contributing to mfai OR creating a py4cast plugin module.
- Neural network architectures MUST be Python classes and inherit from both ModelABC and nn.Module, in that order.
class NewModel(ModelABC, nn.Module):
    settings_kls = NewModelSettings
    onnx_supported: bool = True
    supported_num_spatial_dims = (2,)
    num_spatial_dims: int = 2
    features_last: bool = False
    model_type: int = ModelType.CONVOLUTIONAL
    register: bool = True
An example settings file:
@dataclass_json
@dataclass(slots=True)
class NewModelSettings:
    num_filters: int = 64
    dilation: int = 1
    bias: bool = False
    use_ghost: bool = False
    last_activation: str = "Identity"
You can browse and use the various models in mfai as examples for Graphs, Vision Transformers and CNNs.
- The constructor of the architecture MUST have the following signature:
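def __init__(
    self,
    in_channels: int,
    out_channels: int,
    input_shape: Union[None, Tuple[int, int]] = None,
    # A default is needed here: Python rejects a non-default parameter after a default one.
    settings: HalfUnetSettings = HalfUnetSettings(),
    *args,
    **kwargs,
):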
in_channels is the number of features/channels of the input tensor. out_channels is the number of features/channels of the output tensor, in our case the number of weather parameters to predict. settings is a dataclass instance with the settings of the model.
- ModelABC is a Python ABC which forces you to implement all the required attributes; otherwise an Exception is raised.
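To illustrate what the ABC enforces, here is a minimal, self-contained sketch. `SketchModelABC` and its two abstract properties are hypothetical stand-ins, not the real ModelABC (which requires more attributes). The point is only that Python refuses to instantiate a subclass that leaves an abstract member undefined, and raises a `TypeError`.

```python
from abc import ABC, abstractmethod


class SketchModelABC(ABC):
    """Hypothetical stand-in for ModelABC with only two required attributes."""

    @property
    @abstractmethod
    def onnx_supported(self) -> bool:
        """Whether the model can be exported to ONNX."""

    @property
    @abstractmethod
    def features_last(self) -> bool:
        """Whether the feature dimension is the last tensor dimension."""


class IncompleteModel(SketchModelABC):
    # features_last is missing, so instantiation fails.
    onnx_supported: bool = True


class CompleteModel(SketchModelABC):
    # Plain class attributes are enough to override abstract properties.
    onnx_supported: bool = True
    features_last: bool = False


def try_instantiate(cls) -> str:
    """Return 'ok' or the name of the exception raised on instantiation."""
    try:
        cls()
    except TypeError as exc:
        return type(exc).__name__
    return "ok"
```

Calling `try_instantiate(IncompleteModel)` reports the `TypeError`, while `CompleteModel` instantiates normally.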
Now your model can be either registered explicitly in the system (in case the code is in this repository) or injected in the system as a plugin (in case the code is hosted on a third party repository).
- Model in mfai
This is the recommended approach for stable models. In order to add your model to the mfai repository:
- Make a Pull Request to mfai
- Make a Pull Request to py4cast to update the version of mfai in the requirements.txt file.
- Model as a third party plugin
Prefer this approach to quickly test a new model. Once it is stable, consider contributing to mfai (see previous section).
In order to be discovered, your model Python class MUST:
- be contained in a python module prefixed with py4cast_plugin_
- inherit from ModelABC and nn.Module
- have the register attribute set to True
- have a different name than the models already present in the system
- be discoverable by the system (e.g. in the PYTHONPATH or pip installed)
We provide an example module here to help you create your own plugin. This approach is based on the official python packaging guidelines.
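The naming-convention discovery described in the Python packaging guidelines can be sketched as follows. This is an illustration under stated assumptions, not py4cast's actual discovery code: `discover_plugin_classes` is a hypothetical helper, and it only checks the module prefix and the `register` attribute, leaving out the ModelABC inheritance and name-uniqueness checks listed above.

```python
import importlib
import inspect
import pkgutil

PLUGIN_PREFIX = "py4cast_plugin_"


def discover_plugin_classes(prefix: str = PLUGIN_PREFIX) -> dict:
    """Map class name -> class for every registrable class found in plugin modules."""
    found = {}
    # iter_modules() with no argument lists top-level modules importable from sys.path.
    for module_info in pkgutil.iter_modules():
        if not module_info.name.startswith(prefix):
            continue
        module = importlib.import_module(module_info.name)
        for name, kls in inspect.getmembers(module, inspect.isclass):
            # Keep classes defined in the plugin itself that opt in via register = True.
            if kls.__module__ == module.__name__ and getattr(kls, "register", False):
                found[name] = kls
    return found
```

Any module named `py4cast_plugin_<something>` that is on the PYTHONPATH or pip installed would be picked up by such a scan.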
All datasets inherit from the PyTorch Dataset class and from a specific DatasetABC class that handles "weather-like" inputs. The user only needs to subclass the DataAccessor abstract base class and define the methods that access data on disk.
You must decide how to define the following for your dataset:
- a parameter WeatherParam, i.e. a 2D field corresponding to a weather variable (e.g. geopotential at 500hPa), including long name, level, type of level
- a grid Grid, i.e. the way a 2D field is represented on the geoid. Regular latitude-longitude is one example, but it can be something else (Lambert, etc.).
- a Period (contiguous array of dates and hours describing the temporal extent of your dataset).
- a Timestamp, i.e., for a single sample, how to define the dates and hours involved. In an autoregressive framework, one is typically interested in a day D, hour h and previous / subsequent timesteps: D:h-dt, D:h, D:h+dt, etc.
- Stats: normalization constants used to pre-process the inputs; these are metadata which can be pre-computed and stored in a cache_dir.
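As an illustration of the Timestamp idea, the sketch below lists the dates and hours involved in one sample. `sample_timestamps` and its parameters are hypothetical names chosen for this example. It is not the py4cast Timestamp class, and it assumes the reference time D:h is the last input step.

```python
from datetime import datetime, timedelta


def sample_timestamps(reference, step, num_input_steps, num_pred_steps):
    """Return the datetimes of one sample, from oldest input to last prediction.

    The reference time is the last input step (D:h); earlier inputs and
    later predictions are spaced by `step`.
    """
    offsets = range(-(num_input_steps - 1), num_pred_steps + 1)
    return [reference + k * step for k in offsets]


stamps = sample_timestamps(
    reference=datetime(2023, 1, 1, 12),
    step=timedelta(hours=1),
    num_input_steps=2,
    num_pred_steps=1,
)
# stamps covers D:h-dt, D:h and D:h+dt
```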
This is done with the following steps:
- Your accessor class MUST inherit from DataAccessor.
class TitanAccessor(DataAccessor):
    ...
- Your DataAccessor MUST implement all the abstract methods from DataAccessor.
  - load_grid_info, load_param_info, get_weight_per_level: deciding how to load metadata.
  - get_dataset_path, cache_dir, get_filepath: deciding how to describe directories and files in your architecture.
  - load_data_from_disk, exists: used to verify the validity of a sample and to load individual data chunks from the disk.

  Please override optional_check_before_exists() in your custom DataAccessor if you want to avoid unnecessary file checks as an optimisation.
- This code must be included in a submodule under the py4cast.datasets module, with:
  - the __init__.py file containing the definition of the DataAccessor class
  - additional files, such as settings.py (defining constants such as directories) or metadata.yaml (providing general information on the dataset). While this is up to the user, we recommend following the examples of titan (reanalysis dataset) or poesy (ensemble reforecast dataset).
- You must modify the registry object in the py4cast.datasets.__init__.py module to include your custom DataAccessor. After that, the DataAccessor will be used whenever your dataset name includes the DataAccessor name as a substring.
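The substring matching mentioned in the last step can be pictured with the sketch below. The layout of `registry` (a dict from accessor name to class), the `find_accessor` helper and the placeholder classes are assumptions made for illustration; check py4cast.datasets.__init__.py for the actual structure.

```python
class TitanAccessor:
    """Placeholder standing in for a real DataAccessor subclass."""


class PoesyAccessor:
    """Placeholder standing in for a real DataAccessor subclass."""


# Hypothetical registry layout: accessor name -> accessor class.
registry = {"titan": TitanAccessor, "poesy": PoesyAccessor}


def find_accessor(dataset_name: str):
    """Return the first registered accessor whose name occurs in dataset_name."""
    for accessor_name, accessor_kls in registry.items():
        if accessor_name in dataset_name:
            return accessor_kls
    raise KeyError(f"No accessor registered for dataset '{dataset_name}'")
```

With this scheme, a dataset named `titan_refacto` would resolve to the titan accessor, which is why accessor names should not be substrings of one another.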
Plots are done using the matplotlib library. We wrap each plot in an ErrorObserver class. Below is an example of a plot that shows the spatial distribution of the error for all the variables together. See our observer.py for more examples.
class SpatialErrorPlot(ErrorObserver):
    """
    Produce a map which shows where the errors accumulate (all variables together).
    """

    def __init__(self, prefix: str = "Test"):
        self.spatial_loss_maps = []
        self.prefix = prefix

    def update(
        self,
        obj: "AutoRegressiveLightning",
        prediction: NamedTensor,
        target: NamedTensor,
    ) -> None:
        spatial_loss = obj.loss(prediction, target, reduce_spatial_dim=False)
        # Getting only spatial loss for the required val_step_errors
        if obj.model.info.output_dim == 1:
            spatial_loss = einops.rearrange(
                spatial_loss, "b t (x y) -> b t x y ", x=obj.grid_shape[0]
            )
        self.spatial_loss_maps.append(spatial_loss)  # (B, N_log, N_lat, N_lon)

    def on_step_end(self, obj: "AutoRegressiveLightning") -> None:
        """
        Make the summary figure
        """
In order to add your own plot, you can create a new class that inherits from ErrorObserver and implements the update and on_step_end methods. You can then add your plot to the AutoRegressiveLightning class in the valid_plotters or test_plotters list.
self.test_plotters = [
    StateErrorPlot(metrics),
    SpatialErrorPlot(),
    PredictionPlot(self.hparams["hparams"].num_samples_to_plot),
]
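The update / on_step_end life cycle can be tried without PyTorch or matplotlib using the sketch below. `SketchObserver` is a hypothetical stand-in for ErrorObserver, and `MeanAbsoluteErrorLog` replaces the figure with a plain number; the driver loop only approximates how the plotters are presumably called during a test epoch.

```python
from abc import ABC, abstractmethod


class SketchObserver(ABC):
    """Hypothetical stand-in for ErrorObserver."""

    @abstractmethod
    def update(self, obj, prediction, target) -> None:
        """Called on every step with the prediction and its target."""

    @abstractmethod
    def on_step_end(self, obj) -> None:
        """Called once after the last step to produce the summary."""


class MeanAbsoluteErrorLog(SketchObserver):
    """Accumulate absolute errors, then summarise them at the end."""

    def __init__(self, prefix: str = "Test"):
        self.prefix = prefix
        self.errors = []
        self.summary = None

    def update(self, obj, prediction, target) -> None:
        self.errors.extend(abs(p - t) for p, t in zip(prediction, target))

    def on_step_end(self, obj) -> None:
        # A real plotter would build a matplotlib figure here.
        self.summary = sum(self.errors) / len(self.errors)
        self.errors.clear()


# Driver loop: two steps of (prediction, target) pairs, then the summary.
plotters = [MeanAbsoluteErrorLog()]
for prediction, target in [([1.0, 2.0], [1.0, 4.0]), ([3.0], [2.0])]:
    for plotter in plotters:
        plotter.update(None, prediction, target)
for plotter in plotters:
    plotter.on_step_end(None)
```

Keeping the accumulation in update and the rendering in on_step_end means each plotter holds only the state it needs between steps.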
Anyone can contribute through a merge request. The code must pass the unit tests and continuous integration before it is merged.
We provide a first set of unit tests to ensure the correctness of the codebase. You can run them using the following command:
python -m pytest
Our tests cover:
- The NamedTensor class
- The models: we make sure they can be instantiated and trained in a pure PyTorch training loop.
You can run the same linting checks as the CI:
runai exec ./lint.sh .
You can also reformat your code (black, isort):
runai exec ./reformat.sh .
Conda users should omit the runai exec prefix.
We have a GitHub pipeline that runs linting (flake8, isort, black, bandit) and tests on every push to the repository. See the GitHub workflow file for more details.
Our CI also launches two runs of the full system (bin/train.py) with our Dummy dataset using HiLam and HalfUnet32.
python bin/train.py --model hilam --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1