The Module API is the core design for any task launched by this app. It ensures:

- Proper interaction with `TaskManager`, leveraging distributed computing via `start_task_from_dataset()`.
- Effective collection of the results from all the Dask workers using `result()`, and utilization of caching so computations are only done once.
- Proper interaction with `DatasetSplitManager` to have access to all the dataset updates.
When requesting a module to be computed, different arguments need to be set:

- `dataset_split_name`: Defines on which dataset split the module should be computed. Some modules need both splits and take `DatasetSplitName.all`.
- `config`: Azimuth config defined by the user, which may contain some attributes that will affect the computations.
- `mod_options`: Different module options can be defined that will affect how the computations are done.
- `indices`: For some modules only (see details below), indices can be sent to get the results for the specified indices only.
- `filters`: For some modules only (see details below), filters can be sent to get the results for the filtered dataset only.
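As a rough illustration, requesting a module might look like the following; the import paths, `MyModule`, and the argument names shown here are assumptions for illustration, not the definitive API:

```python
from azimuth.types import DatasetSplitName, ModuleOptions

# A minimal sketch; MyModule is a hypothetical module class.
module = MyModule(
    dataset_split_name=DatasetSplitName.train,
    config=config,  # Azimuth config defined by the user
    mod_options=ModuleOptions(indices=[0, 1, 2]),  # results for these indices only
)
```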
Each task follows the same Module API. We implement `compute_on_dataset_split()`, which gets the result on the entire dataset split, and `compute()`, which computes on a specific batch. On some tasks, there is no need to implement `compute()`, such as Aggregation Modules working on full datasets.
```python
from typing import List

from datasets import Dataset

from azimuth.modules.base_classes import Module
from azimuth.types import ModuleResponse


class YourModule(Module):
    def compute_on_dataset_split(self) -> List[ModuleResponse]:
        ...

    def compute(self, batch: Dataset) -> List[ModuleResponse]:
        ...
```
Each `Module` returns a `List[ModuleResponse]`, which is the result for each index or set of indices requested. `ModuleResponse` is a pydantic object that you need to inherit from.

- The length of this list will vary based on the parent module, as described in the next section.
- We usually override the return type to an object that better describes the output of the module. The new type should still be wrapped in a list. So that `mypy` does not complain, we add `# type: ignore` to the function signature line, as in the example below.
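For example, a module could declare a more specific response type along these lines; `MyResponse` and its `score` field are hypothetical:

```python
from typing import List

from datasets import Dataset

from azimuth.modules.base_classes import Module
from azimuth.types import ModuleResponse


class MyResponse(ModuleResponse):
    score: float  # hypothetical field describing this module's output


class MyModule(Module):
    def compute(self, batch: Dataset) -> List[MyResponse]:  # type: ignore
        ...
```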
When creating a new module, one crucial decision is to determine what type of Module it should inherit from. Several parent classes are available, each with its own particularities.
- Module
    - This is the parent module. No Module should directly inherit from it. It includes a lot of the basic methods that apply to all modules.
    - All module options are ignored by default (`required_mod_options` and `optional_mod_options` are empty) in this module.
- Indexable Module
    - It works based on indices, and its results will be the same length as the number of indices that were requested. This has crucial implications for caching: the cache can be retrieved per index, i.e. if we request results for a set of indices that is partially different from a previous set, only the missing indices will be computed.
    - `optional_mod_options` includes `indices` for that purpose.
    - `.get_dataset_split()` returns a subset of the `dataset_split`, based on `indices`.
- Dataset Result Module
    - It inherits from Indexable Module.
    - It implements a new method, `.save_result()`, which saves the result of the module in the HuggingFace dataset.
- Model Contract Module
    - It inherits from Dataset Result Module, and is a module that wraps models.
    - These modules handle all the tasks which need direct access to the model to compute their results (e.g. prediction, saliency).
    - `required_mod_options` include both `model_contract_method_name` and `pipeline_index`. `optional_mod_options` include options that can affect models, such as the `threshold`.
- Aggregation Module
    - This module inherits from `Module`, and so the same methods and attributes are available.
    - This module performs an aggregation, which makes the length of the module's result not equal to the number of data points that were requested. For example, the metric module will compute `X` metrics based on a set of `n` data points. This has crucial implications for caching, as the module needs to be recomputed as soon as the set of data points changes. For that reason, the result of an aggregation task needs to be a list of length one.
- Filterable Module
    - This module inherits from `AggregationModule` and is similar in all points.
    - `required_mod_options` include `pipeline_index`, since all filterable modules require filtering on the dataset, which involves predictions. `optional_mod_options` include `filters`, which means the dataset can be filtered using `DatasetFilters` (as opposed to regular modules, which use indices only).
    - As such, `.get_dataset_split()` returns a subset of the `dataset_split`, based on `filters`.
- Comparison Module
    - This module inherits from `AggregationModule` and is similar in all points. The only difference is that it works on multiple dataset splits at the same time. As such, it only accepts `DatasetSplitName.all`.
⚠️ One more time: All aggregation modules need to have a list of length one as a result. Indexable modules need to have a result length equal to the number of indices (or to the full dataset split if no indices are specified). More details on caching are available in the section Caching below.
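To make the length-one contract concrete, here is a minimal sketch of an aggregation module, assuming `AggregationModule` is exposed alongside `Module` in `azimuth.modules.base_classes`; the response type and helper are hypothetical:

```python
from typing import List

from datasets import Dataset

from azimuth.modules.base_classes import AggregationModule  # assumed location
from azimuth.types import ModuleResponse


class MetricsResponse(ModuleResponse):
    accuracy: float  # hypothetical aggregated value


def my_accuracy_fn(ds: Dataset) -> float:
    """Hypothetical helper computing an aggregate over the dataset split."""
    ...


class MyAggregationModule(AggregationModule):
    def compute_on_dataset_split(self) -> List[MetricsResponse]:  # type: ignore
        ds = self.get_dataset_split()
        # One response summarizing all requested data points:
        # the result must be a list of length one.
        return [MetricsResponse(accuracy=my_accuracy_fn(ds))]
```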
Because the user can make changes to the dataset, we added a Mixin for Modules that need to be recomputed when the user changes the dataset.

- Expirable Mixin
    - This module needs to be recomputed every time the user makes a change to the dataset. This usually happens because of the use of data action tags in the `filters` module option, which can lead to a different set of data points once the user starts to annotate. This also happens when thresholds are edited in the config, which recomputes some smart tags.
    - Since it is a Mixin, it is added as a second inheritance to Modules that need it, as in the sketch below.
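A minimal sketch of attaching the mixin via second inheritance; the `ExpirableMixin` name and import path are assumptions:

```python
from azimuth.modules.base_classes import AggregationModule, ExpirableMixin  # assumed


# The second inheritance marks this module for recomputation whenever the
# user changes the dataset (e.g. through data action tags).
class MyExpirableModule(AggregationModule, ExpirableMixin):
    ...
```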
A Module is assigned a scope; for example, `ModelContractConfig` is used when you need access to the model. The caching mechanism is based on the Module's scope. This way, we can have more robust caching where we invalidate the cache only when a field used by the Module is changed.

You can define your Module as such to use a Scope. Scopes are defined in ../config.py.
```python
class YourModule(Module[ModelContractConfig]):
    ...
```

While it is possible to not assign a Scope, `mypy` won't be able to help you.
Some class variables may need to be overridden for a new module, to ensure that the module throws if it is not called accordingly. This also has crucial implications for the cache: the module results will not be needlessly regenerated. For example, a module should throw if it is called with the wrong `dataset_split`, or with `mod_options` that should not affect the module's results.
`allowed_splits` defines which `dataset_split_name` values can be sent to the module.

- The default is that both train and test can be sent.
- We override this attribute when necessary, such as for Comparison Modules.
`required_mod_options` is a set containing all `mod_options` that are mandatory for a given module.

- By default, the set is empty.
- Different default values are available in the different parent modules below. When overriding the default values, the values from the parent module should be added to the set, as in the sketch below.

Similarly, `optional_mod_options` is a set containing `mod_options` that are optional for a given module.
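For instance, a sketch of overriding these class variables; the parent class and the specific option and split values shown are illustrative assumptions:

```python
from azimuth.modules.base_classes import Module
from azimuth.types import DatasetSplitName


class MyModule(Module):
    # Restrict which splits may be sent; a sketch, assuming allowed_splits
    # holds DatasetSplitName values.
    allowed_splits = {DatasetSplitName.train}
    # Extend the parent's defaults instead of replacing them.
    required_mod_options = Module.required_mod_options | {"pipeline_index"}
    optional_mod_options = Module.optional_mod_options | {"threshold"}
```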
When you call `self.get_dataset_manager()`, `self.get_dataset()`, or `self.get_model()`, the value is stored in a singleton from `ArtifactManager`. This can cause issues when we send a `Module` over the distributed network, as we need to copy these objects, which can add latency to the network. It is safe to run these methods during `compute_on_dataset_split()` and `compute()`, as these run on the workers.
In addition, we have sections of the code where we pass the `DatasetSplitManager` directly, `save_result` for example. There shouldn't be any occurrence where we need the model or the dataset outside a Module, except when calling `self.get_indices()` in caching methods from `Module`, and in tests.
Once the new module is implemented, it needs to be added to the set of supported modules that the task manager can launch, as sketched after this list.

- Add your new module to the relevant enum:
    - For regular modules, it should be in `azimuth.types.general.modules.SupportedModule`.
    - For domain specific modules, if you add a method, it should be added in `azimuth.types.general.modules.SupportedMethod`. The module should be added to `azimuth.types.general.modules.SupportedModelContract`.
- For Domain Specific Modules only, add a method in `azimuth.modules.module.DomainSpecificModule` if your module implements a new method. Define its route in `route_request()`.
- Add a new mapping so the task manager knows what module should be launched:
    - For domain specific modules, it should be both in `azimuth.modules.model_contract_task_mapping` and `azimuth.modules.task_mapping`.
    - For all other modules, it should be in `azimuth.modules.task_mapping`.
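A hedged sketch of the registration step for a regular module; the enum base, member name, module location, and mapping variable name are all assumptions for illustration:

```python
from enum import Enum

from my_package.my_module import MyModule  # hypothetical location of your module


# Sketch of the enum in azimuth.types.general.modules; existing members elided.
class SupportedModule(str, Enum):
    MyModule = "MyModule"  # hypothetical new member


# Sketch of azimuth.modules.task_mapping; the variable name is an assumption.
modules = {
    SupportedModule.MyModule: MyModule,
}
```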
There are different ways to call a module. Calling `compute_on_dataset_split()` or `compute()` directly works for tests, but should not be used otherwise, since it ignores the cache.
There are two ways to launch a task:

- If you are calling a `Module` from a place where a `TaskManager` is defined, such as the start-up task, you can use it to launch the task.
- If you are calling a `Module` where you don't have access to a `TaskManager`, such as from another module, you can use `azimuth.modules.task_execution.get_task_result` to get the result safely.
You can call `TaskManager.get_task()` with the `Module` parameters to get the `Module` back. If the task has been requested before, it will return the same object. If the `Module` is not `done()`, the task will be launched. Be aware that the `Module` computation is not necessarily completed when you get the `Module` back.

In addition, we can supply `dependencies`, which is a list of other Modules that we need to be completed before we start the computation. This is useful when the dependency edits the HuggingFace Dataset that our `Module` needs for its computation.
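The following is a hypothetical sketch of that flow; the `get_task()` argument names are assumptions, not the definitive signature:

```python
# Hypothetical sketch; the argument names below are assumptions.
module = task_manager.get_task(
    task_name=SupportedModule.MyModule,        # hypothetical enum member
    dataset_split_name=DatasetSplitName.train,
    mod_options=mod_options,
    dependencies=[prerequisite_module],        # must be completed first
)
if not module.done():
    # The computation may still be in progress when the Module is returned.
    ...
```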
The function `azimuth.modules.task_execution.get_task_result` will launch the task for you on a Dask client. Note the `result_type` argument; it is only used for `mypy` so that `get_task_result` returns the correct type. See the example below for how you would get the results of another module from `MyModule`.
```python
from typing import List

from azimuth.modules.base_classes import Module
from azimuth.modules.task_execution import get_task_result
from azimuth.types import ModuleResponse


class AnotherModule(Module):
    ...


class MyModule(Module):
    def compute_on_dataset_split(self):
        inner_module = AnotherModule(...)
        result = get_task_result(task_module=inner_module, result_type=List[ModuleResponse])
```
You can explore the function `get_task_result` in our repo to see how we first check the cache, get a client, and start the task.
Every module result is cached in an HDF5 file. This allows for random access to the file and avoids re-computation. The caching is implemented in `azimuth.modules.caching.HDF5CacheMixin`. A new cache is generated every time the `.name` of the module changes. The `.name` is generated automatically based on the `task_name`, the `dataset_split_name`, and two stripped hash sequences, based on either `required_mod_options`/`optional_mod_options` values (minus `indices`) or `config` attribute values. By default, a new cache folder is created when one of the attribute values from the `ProjectConfig` is changed. This includes attributes related to the dataset.
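For indexable modules, excluding `indices` from the hash has a concrete consequence, illustrated below; `MyIndexableModule` and the constructor argument names are assumptions:

```python
# Illustrative sketch; constructor argument names are assumptions.
m1 = MyIndexableModule(
    dataset_split_name=DatasetSplitName.train,
    config=config,
    mod_options=ModuleOptions(indices=[0, 1]),
)
m2 = MyIndexableModule(
    dataset_split_name=DatasetSplitName.train,
    config=config,
    mod_options=ModuleOptions(indices=[2, 3]),
)
# `indices` are excluded from the option hash, so both share the same `.name`
# and thus the same HDF5 cache; only the missing indices get computed.
assert m1.name == m2.name
```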