2310 Add load_csv_datalist utility API (Project-MONAI#2349)
* [DLMED] add CSV datalist

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add group feature

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add unit test

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add more unit tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add optional install

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add doc-strings

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix typo

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add CSVDataset for non-iterable data

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix min test

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add CSVIterableDataset base

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add CSVIterableDataset

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] support multiple processes

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix docs-build

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix min tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix CI tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix typo

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] change sys.platform

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] skip if windows

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] add col_types arg

Signed-off-by: Nic Ma <[email protected]>

Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
4 people authored Jun 22, 2021
1 parent 8cda6c1 commit 075bccd
Showing 13 changed files with 622 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -18,3 +18,4 @@ sphinxcontrib-jsmath
sphinxcontrib-qthelp
sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
12 changes: 12 additions & 0 deletions docs/source/data.rst
@@ -21,6 +21,12 @@ Generic Interfaces
:members:
:special-members: __next__

`CSVIterableDataset`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: CSVIterableDataset
:members:
:special-members: __next__

`PersistentDataset`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: PersistentDataset
@@ -75,6 +81,12 @@ Generic Interfaces
:members:
:special-members: __getitem__

`CSVDataset`
~~~~~~~~~~~~
.. autoclass:: CSVDataset
:members:
:special-members: __getitem__

Patch-based dataset
-------------------

4 changes: 2 additions & 2 deletions docs/source/installation.md
@@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

- The options are
```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb` and `psutil`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python` and `pandas`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
1 change: 1 addition & 0 deletions monai/config/deviceconfig.py
@@ -79,6 +79,7 @@ def get_optional_config_values():
output["tqdm"] = get_package_version("tqdm")
output["lmdb"] = get_package_version("lmdb")
output["psutil"] = psutil_version
output["pandas"] = get_package_version("pandas")

return output
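
With pandas now listed among the optional config values, a quick way to check that it is detected is to print the MONAI config. A small sketch; the reported versions depend on the local environment:

from monai.config import print_config

# prints system info plus the optional dependency versions,
# including the newly added "pandas" entry (or a "not installed" message)
print_config()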

4 changes: 3 additions & 1 deletion monai/data/__init__.py
@@ -15,6 +15,7 @@
ArrayDataset,
CacheDataset,
CacheNTransDataset,
CSVDataset,
Dataset,
LMDBDataset,
NPZDictItemDataset,
@@ -26,7 +27,7 @@
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .iterable_dataset import IterableDataset
from .iterable_dataset import CSVIterableDataset, IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
@@ -38,6 +39,7 @@
from .utils import (
compute_importance_map,
compute_shape_offset,
convert_tables_to_dicts,
correct_nifti_header_if_necessary,
create_file_basename,
decollate_batch,
74 changes: 72 additions & 2 deletions monai/data/dataset.py
@@ -29,9 +29,9 @@
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset

from monai.data.utils import first, pickle_hashing
from monai.data.utils import convert_tables_to_dicts, first, pickle_hashing
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
from monai.utils import MAX_SEED, get_seed, min_version, optional_import
from monai.utils import MAX_SEED, ensure_tuple, get_seed, min_version, optional_import

if TYPE_CHECKING:
from tqdm import tqdm
@@ -41,6 +41,7 @@
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

lmdb, _ = optional_import("lmdb")
pd, _ = optional_import("pandas")


class Dataset(_TorchDataset):
@@ -1061,3 +1062,72 @@ def _transform(self, index: int):
data = apply_transform(self.transform, data)

return data


class CSVDataset(Dataset):
"""
Dataset to load data from CSV files and generate a list of dictionaries,
where every dictionary maps to a row of the CSV file and the keys of the dictionary
map to the column names of the CSV file.
It can load multiple CSV files and join the tables with additional `kwargs` arg.
It supports loading only specific rows and columns,
and it can also group several loaded columns into a new column, for example,
set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, the output can be::
[
{"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
{"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
]
Args:
filename: the filename of the expected CSV file to load. if providing a list
of filenames, it will load all the files and join the tables.
row_indices: indices of the expected rows to load. it should be a list,
every item can be an int number or a range `[start, end)` for the indices.
for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
load all the rows in the file.
col_names: names of the expected columns to load. if None, load all the columns.
col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
it should be a dictionary, every item maps to an expected column, the `key` is the column
name and the `value` is None or a dictionary to define the default value and data type.
the supported keys in dictionary are: ["type", "default"]. for example::
col_types = {
"subject_id": {"type": str},
"label": {"type": int, "default": 0},
"ehr_0": {"type": float, "default": 0.0},
"ehr_1": {"type": float, "default": 0.0},
"image": {"type": str, "default": None},
}
col_groups: args to group the loaded columns to generate a new column,
it should be a dictionary, every item maps to a group, the `key` will
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
transform: transform to apply on the loaded items of a dictionary data.
kwargs: additional arguments for `pandas.merge()` API to join tables.
"""

def __init__(
self,
filename: Union[str, Sequence[str]],
row_indices: Optional[Sequence[Union[int, str]]] = None,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
**kwargs,
):
files = ensure_tuple(filename)
dfs = [pd.read_csv(f) for f in files]
data = convert_tables_to_dicts(
dfs=dfs,
row_indices=row_indices,
col_names=col_names,
col_types=col_types,
col_groups=col_groups,
**kwargs,
)
super().__init__(data=data, transform=transform)
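
As a rough usage sketch (the file name `./subjects.csv` and its column names below are hypothetical, not part of this change):

from monai.data import CSVDataset

# load a hypothetical "./subjects.csv" (assumed to have enough rows), keeping rows 0-99
# plus row 200, filling missing labels with 0 and grouping the EHR columns into one field
dataset = CSVDataset(
    filename="./subjects.csv",
    row_indices=[[0, 100], 200],
    col_names=["subject_id", "label", "ehr_0", "ehr_1", "image"],
    col_types={"label": {"type": int, "default": 0}},
    col_groups={"ehr": ["ehr_0", "ehr_1"]},
)
print(dataset[0])  # a dict such as {"subject_id": ..., "label": ..., "ehr": [...], ...}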
99 changes: 98 additions & 1 deletion monai/data/iterable_dataset.py
@@ -9,11 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, Optional
import math
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union

from torch.utils.data import IterableDataset as _TorchIterableDataset
from torch.utils.data import get_worker_info

from monai.data.utils import convert_tables_to_dicts
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, optional_import

pd, _ = optional_import("pandas")


class IterableDataset(_TorchIterableDataset):
@@ -43,3 +49,94 @@ def __iter__(self):
if self.transform is not None:
data = apply_transform(self.transform, data)
yield data


class CSVIterableDataset(IterableDataset):
"""
Iterable dataset to load CSV files and generate dictionary data.
It can be helpful when loading extremely big CSV files that can't be read into memory directly.
To accelerate the loading process, it supports multi-processing based on PyTorch DataLoader workers;
every worker process executes transforms on part of every loaded chunk.
Note: the order of the output data may not match the data source in multi-processing mode.
It can load data from multiple CSV files and join the tables with additional `kwargs` arg.
It supports loading only specific columns,
and it can also group several loaded columns into a new column, for example,
set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, the output can be::
[
{"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
{"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
]
Args:
filename: the filename of the expected CSV file to load. if providing a list
of filenames, it will load all the files and join the tables.
chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
col_names: names of the expected columns to load. if None, load all the columns.
col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
it should be a dictionary, every item maps to an expected column, the `key` is the column
name and the `value` is None or a dictionary to define the default value and data type.
the supported keys in dictionary are: ["type", "default"]. for example::
col_types = {
"subject_id": {"type": str},
"label": {"type": int, "default": 0},
"ehr_0": {"type": float, "default": 0.0},
"ehr_1": {"type": float, "default": 0.0},
"image": {"type": str, "default": None},
}
col_groups: args to group the loaded columns to generate a new column,
it should be a dictionary, every item maps to a group, the `key` will
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
transform: transform to apply on the loaded items of a dictionary data.
kwargs: additional arguments for `pandas.merge()` API to join tables.
"""

def __init__(
self,
filename: Union[str, Sequence[str]],
chunksize: int = 1000,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
**kwargs,
):
self.files = ensure_tuple(filename)
self.chunksize = chunksize
self.iters = self.reset()
self.col_names = col_names
self.col_types = col_types
self.col_groups = col_groups
self.kwargs = kwargs
super().__init__(data=None, transform=transform) # type: ignore

def reset(self, filename: Optional[Union[str, Sequence[str]]] = None):
if filename is not None:
# update files if necessary
self.files = ensure_tuple(filename)
self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files]
return self.iters

def __iter__(self):
for chunks in zip(*self.iters):
self.data = convert_tables_to_dicts(
dfs=chunks,
col_names=self.col_names,
col_types=self.col_types,
col_groups=self.col_groups,
**self.kwargs,
)
info = get_worker_info()
if info is not None:
length = len(self.data)
per_worker = int(math.ceil(length / float(info.num_workers)))
start = info.id * per_worker
self.data = self.data[start : min(start + per_worker, length)]

return super().__iter__()
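
A minimal sketch of pairing this dataset with a PyTorch DataLoader for multi-process chunked loading (the file names and chunk size below are hypothetical):

from torch.utils.data import DataLoader

from monai.data import CSVIterableDataset

# two hypothetical CSV files are joined chunk by chunk (here 100 rows per chunk)
dataset = CSVIterableDataset(
    filename=["./scans.csv", "./labels.csv"],
    chunksize=100,
    col_groups={"meta": ["meta_0", "meta_1", "meta_2"]},
)
# each DataLoader worker processes its own shard of every chunk,
# so the output order may differ from the source CSV order
loader = DataLoader(dataset, batch_size=4, num_workers=2)
for batch in loader:
    ...  # consume the batches

# the same object can later be pointed at different files
dataset.reset(filename="./new_scans.csv")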
84 changes: 83 additions & 1 deletion monai/data/utils.py
@@ -16,9 +16,10 @@
import pickle
import warnings
from collections import defaultdict
from functools import reduce
from itertools import product, starmap
from pathlib import PurePath
from typing import Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -37,8 +38,11 @@
)
from monai.utils.enums import Method

pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")


__all__ = [
"get_random_patch",
"iter_patch_slices",
@@ -65,6 +69,7 @@
"decollate_batch",
"pad_list_data_collate",
"no_collation",
"convert_tables_to_dicts",
]


@@ -983,3 +988,80 @@ def sorted_dict(item, key=None, reverse=False):
if not isinstance(item, dict):
return item
return {k: sorted_dict(v) if isinstance(v, dict) else v for k, v in sorted(item.items(), key=key, reverse=reverse)}


def convert_tables_to_dicts(
dfs,
row_indices: Optional[Sequence[Union[int, str]]] = None,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Utility to join pandas tables, select rows and columns, and generate groups.
It returns a list of dictionaries, where every dictionary maps to a row of data in the tables.
Args:
dfs: data table in pandas DataFrame format. if providing a list of tables, it will join them.
row_indices: indices of the expected rows to load. it should be a list,
every item can be an int number or a range `[start, end)` for the indices.
for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
load all the rows in the file.
col_names: names of the expected columns to load. if None, load all the columns.
col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
it should be a dictionary, every item maps to an expected column, the `key` is the column
name and the `value` is None or a dictionary to define the default value and data type.
the supported keys in dictionary are: ["type", "default"], and note that the value of `default`
should not be `None`. for example::
col_types = {
"subject_id": {"type": str},
"label": {"type": int, "default": 0},
"ehr_0": {"type": float, "default": 0.0},
"ehr_1": {"type": float, "default": 0.0},
}
col_groups: args to group the loaded columns to generate a new column,
it should be a dictionary, every item maps to a group, the `key` will
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
kwargs: additional arguments for `pandas.merge()` API to join tables.
"""
df = reduce(lambda l, r: pd.merge(l, r, **kwargs), ensure_tuple(dfs))
# parse row indices
rows: List[Union[int, str]] = []
if row_indices is None:
rows = slice(df.shape[0]) # type: ignore
else:
for i in row_indices:
if isinstance(i, (tuple, list)):
if len(i) != 2:
raise ValueError("range of row indices must contain 2 values: start and end.")
rows.extend(list(range(i[0], i[1])))
else:
rows.append(i)

# convert to a list of dictionaries corresponding to every row
data_ = df.loc[rows] if col_names is None else df.loc[rows, col_names]
if isinstance(col_types, dict):
# fill default values for NaN
defaults = {k: v["default"] for k, v in col_types.items() if v is not None and v.get("default") is not None}
if len(defaults) > 0:
data_ = data_.fillna(value=defaults)
# convert data types
types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v}
if len(types) > 0:
data_ = data_.astype(dtype=types)
data: List[Dict] = data_.to_dict(orient="records")

# group columns to generate new column
if col_groups is not None:
groups: Dict[str, List] = {}
for name, cols in col_groups.items():
groups[name] = df.loc[rows, cols].values
# invert items of groups to every row of data
data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]

return data
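
For illustration, a small sketch of calling this utility directly on in-memory DataFrames (the table contents below are made up):

import pandas as pd

from monai.data.utils import convert_tables_to_dicts

df_a = pd.DataFrame({"subject_id": ["s1", "s2"], "label": [0, None]})
df_b = pd.DataFrame({"subject_id": ["s1", "s2"], "ehr_0": [1.0, 2.0], "ehr_1": [3.0, 4.0]})

items = convert_tables_to_dicts(
    dfs=[df_a, df_b],
    col_types={"label": {"type": int, "default": 0}},  # fill the NaN label with 0, cast to int
    col_groups={"ehr": ["ehr_0", "ehr_1"]},            # combine the EHR columns into one field
    on="subject_id",                                   # forwarded to pandas.merge via **kwargs
)
# items is a list of dicts, one per row, each with an extra "ehr" entry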