Skip to content

Commit

Permalink
Merge branch 'dev' into 7499-torchio-transforms-wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ericspod authored Nov 27, 2024
2 parents 2a7842d + 20372f0 commit 8b286ce
Show file tree
Hide file tree
Showing 21 changed files with 1,012 additions and 117 deletions.
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,11 @@ Nets
.. autoclass:: ViTAutoEnc
:members:

`MaskedAutoEncoderViT`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MaskedAutoEncoderViT
:members:

`FullyConnectedNet`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FullyConnectedNet
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@
MACRO_KEY,
load_bundle_config,
)
from .workflows import BundleWorkflow, ConfigWorkflow
from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow
10 changes: 10 additions & 0 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str
"""
return self._resolve_one_item(id=id, **kwargs)

def remove_resolved_content(self, id: str) -> Any | None:
"""
Remove the resolved ``ConfigItem`` by id.
Args:
id: id name of the expected item.
"""
return self.resolved_content.pop(id) if id in self.resolved_content else None

@classmethod
def normalize_id(cls, id: str | int) -> str:
"""
Expand Down
209 changes: 187 additions & 22 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,18 @@ class BundleWorkflow(ABC):
workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `train` for train workflow.
default to `None` for only using meta properties.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties.
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
properties will default to loading from "meta". If `properties_path` is None, default properties
will be sourced from "monai/bundle/properties.py" based on the workflow_type:
For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
For workflow_type = None : only `MetaProperties` will be loaded.
meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
Expand Down Expand Up @@ -97,29 +103,50 @@ def __init__(
meta_file = None

workflow_type = workflow if workflow is not None else workflow_type
if workflow_type is None and properties_path is None:
self.properties = copy(MetaProperties)
self.workflow_type = None
self.meta_file = meta_file
return
if workflow_type is not None:
if workflow_type.lower() in self.supported_train_type:
workflow_type = "train"
elif workflow_type.lower() in self.supported_infer_type:
workflow_type = "infer"
else:
raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")

if properties_path is not None:
properties_path = Path(properties_path)
if not properties_path.is_file():
raise ValueError(f"Property file {properties_path} does not exist.")
with open(properties_path) as json_file:
self.properties = json.load(json_file)
self.workflow_type = None
self.meta_file = meta_file
return
if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr]
self.properties = {**TrainProperties, **MetaProperties}
self.workflow_type = "train"
elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr]
self.properties = {**InferProperties, **MetaProperties}
self.workflow_type = "infer"
try:
properties = json.load(json_file)
self.properties: dict = {}
if workflow_type is not None and workflow_type in properties:
self.properties = properties[workflow_type]
if "meta" in properties:
self.properties.update(properties["meta"])
elif workflow_type is None:
if "meta" in properties:
self.properties = properties["meta"]
logger.info(
"No workflow type specified, default to load meta properties from property file."
)
else:
logger.warning("No 'meta' key found in properties while workflow_type is None.")
except KeyError as e:
raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e
except json.JSONDecodeError as e:
raise ValueError(f"Error decoding JSON from property file {properties_path}") from e
else:
raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
if workflow_type == "train":
self.properties = {**TrainProperties, **MetaProperties}
elif workflow_type == "infer":
self.properties = {**InferProperties, **MetaProperties}
elif workflow_type is None:
self.properties = copy(MetaProperties)
logger.info("No workflow type and property file specified, default to 'meta' properties.")
else:
raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")

self.workflow_type = workflow_type
self.meta_file = meta_file

@abstractmethod
Expand Down Expand Up @@ -226,6 +253,124 @@ def check_properties(self) -> list[str] | None:
return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]


class PythonicWorkflow(BundleWorkflow):
"""
Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow.
It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc.
This also provides the interface to get / set public properties to interact with a bundle workflow through
defined `get_<property>` accessor methods or directly defining members of the object.
For how to set the properties, users can define the `_set_<property>` methods or directly set the members of the object.
The `initialize` method is called to set up the workflow before running. This method sets up internal state
and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized`
is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is
properly set up with the new property values.
Args:
workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for only using meta properties.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
properties will default to loading from "meta". If `properties_path` is None, default properties
will be sourced from "monai/bundle/properties.py" based on the workflow_type:
For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
For workflow_type = None : only `MetaProperties` will be loaded.
config_file: path to the config file, typically used to store hyperparameters.
meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
"""

supported_train_type: tuple = ("train", "training")
supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")

def __init__(
self,
workflow_type: str | None = None,
properties_path: PathLike | None = None,
config_file: str | Sequence[str] | None = None,
meta_file: str | Sequence[str] | None = None,
logging_file: str | None = None,
**override: Any,
):
meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file
super().__init__(
workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file
)
self._props_vals: dict = {}
self._set_props_vals: dict = {}
self.parser = ConfigParser()
if config_file is not None:
self.parser.read_config(f=config_file)
if self.meta_file is not None:
self.parser.read_meta(f=self.meta_file)

# the rest key-values in the _args are to override config content
self.parser.update(pairs=override)
self._is_initialized: bool = False

def initialize(self, *args: Any, **kwargs: Any) -> Any:
"""
Initialize the bundle workflow before running.
"""
self._props_vals = {}
self._is_initialized = True

def _get_property(self, name: str, property: dict) -> Any:
"""
With specified property name and information, get the expected property value.
If the property is already generated, return from the bucket directly.
If user explicitly set the property, return it directly.
Otherwise, generate the expected property as a class private property with prefix "_".
Args:
name: the name of target property.
property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
"""
if not self._is_initialized:
raise RuntimeError("Please execute 'initialize' before getting any properties.")
value = None
if name in self._set_props_vals:
value = self._set_props_vals[name]
elif name in self._props_vals:
value = self._props_vals[name]
elif name in self.parser.config[self.parser.meta_key]: # type: ignore[index]
id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None)
value = self.parser[id]
else:
try:
value = getattr(self, f"get_{name}")()
except AttributeError as e:
if property[BundleProperty.REQUIRED]:
raise ValueError(
f"unsupported property '{name}' is required in the bundle properties,"
f"need to implement a method 'get_{name}' to provide the property."
) from e
self._props_vals[name] = value
return value

def _set_property(self, name: str, property: dict, value: Any) -> Any:
"""
With specified property name and information, set value for the expected property.
Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized.
Args:
name: the name of target property.
property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
value: value to set for the property.
"""
self._set_props_vals[name] = value
self._is_initialized = False


class ConfigWorkflow(BundleWorkflow):
"""
Specification for the config-based bundle workflow.
Expand Down Expand Up @@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties.
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
properties will default to loading from "train". If `properties_path` is None, default properties
will be sourced from "monai/bundle/properties.py" based on the workflow_type:
For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
For workflow_type = None : only `MetaProperties` will be loaded.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
Expand Down Expand Up @@ -324,7 +475,6 @@ def __init__(
self.parser.read_config(f=config_file)
if self.meta_file is not None:
self.parser.read_meta(f=self.meta_file)

# the rest key-values in the _args are to override config content
self.parser.update(pairs=override)
self.init_id = init_id
Expand Down Expand Up @@ -394,8 +544,23 @@ def check_properties(self) -> list[str] | None:
ret.extend(wrong_props)
return ret

def _run_expr(self, id: str, **kwargs: dict) -> Any:
return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
"""
Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
"""
ret = []
if id in self.parser:
# suppose all the expressions are in a list, run and reset the expressions
if isinstance(self.parser[id], list):
for i in range(len(self.parser[id])):
sub_id = f"{id}{ID_SEP_KEY}{i}"
ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
self.parser.ref_resolver.remove_resolved_content(sub_id)
else:
ret.append(self.parser.get_parsed_content(id, **kwargs))
self.parser.ref_resolver.remove_resolved_content(id)
return ret

def _get_prop_id(self, name: str, property: dict) -> Any:
prop_id = property[BundlePropertyConfig.ID]
Expand Down
22 changes: 18 additions & 4 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Tuple, Union
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -154,10 +154,12 @@ def __init__(
)
self.input_size = input_size

def forward(self, x):
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
B x (s_dim_1 * ... * s_dim_n). Defaults to None.
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
Expand All @@ -176,7 +178,13 @@ def forward(self, x):

if self.use_flash_attention:
x = F.scaled_dot_product_attention(
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
query=q,
key=k,
value=v,
attn_mask=attn_mask,
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
Expand All @@ -186,10 +194,16 @@ def forward(self, x):
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
if attn_mask is not None:
raise ValueError("Causal attention does not support attention masks.")
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
Expand Down
6 changes: 4 additions & 2 deletions monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def __init__(
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
def forward(
self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
if self.with_cross_attention:
x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
x = x + self.mlp(self.norm2(x))
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .masked_autoencoder_vit import MaskedAutoEncoderViT
from .mednext import (
MedNeXt,
MedNext,
Expand Down
Loading

0 comments on commit 8b286ce

Please sign in to comment.