diff --git a/README.md b/README.md index 902460e6..b4679886 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,32 @@ import yoyodyne ## Usage -See [`yoyodyne-predict --help`](yoyodyne/predict.py) and -[`yoyodyne-train --help`](yoyodyne/train.py). +### Training + +Training is performed by the [`yoyodyne-train`](yoyodyne/train.py) script. One +must specify the following required arguments: + +- `--model_dir`: path for model metadata and checkpoints +- `--experiment`: name of experiment (pick something unique) +- `--train`: path to TSV file containing training data +- `--val`: path to TSV file containing validation data + +The user can also specify various optional training and architectural +arguments. See below or run [`yoyodyne-train --help`](yoyodyne/train.py) for +more information. + +### Prediction + +Prediction is performed by the [`yoyodyne-predict`](yoyodyne/predict.py) +script. One must specify the following required arguments: + +- `--model_dir`: path for model metadata +- `--experiment`: name of experiment +- `--checkpoint`: path to checkpoint +- `--predict`: path to TSV file containing data to be predicted +- `--output`: path for predictions + +Run [`yoyodyne-predict --help`](yoyodyne/predict.py) for more information. ## Data format @@ -73,11 +97,12 @@ the third contains semi-colon delimited feature strings: this format is specified by `--features-col 3`. -Alternatively, for the SIGMORPHON 2016 shared task data format: +Alternatively, for the [SIGMORPHON 2016 shared +task](https://sigmorphon.github.io/sharedtasks/2016/) data: source feat1,feat2,... target -this format is specified by `--features-col 2 --features-sep , --target-col 3`. +this format is specified by `--features_col 2 --features_sep , --target_col 3`. In order to ensure that targets are ignored during prediction, one can specify `--target_col 0`. @@ -101,7 +126,7 @@ checkpoint, the **full path to the checkpoint** should be specified with for the the provided model checkpoint. During training, we save the best `--save_top_k` checkpoints (by default, 1) -ranked according to accuracy on the `--dev` set. For example, `--save_top_k 5` +ranked according to accuracy on the `--val` set. For example, `--save_top_k 5` will save the top 5 most accurate models. ## Reserved symbols @@ -153,8 +178,7 @@ source encoder using the `--source_encoder` flag: - `"transformer"`: This is a transformer encoder. When using features, the user can also specify a non-default features encoder -using the `--features_encoder` flag (`"linear"`, `"lstm"`, -`"transformer"`). +using the `--features_encoder` flag (`"linear"`, `"lstm"`, `"transformer"`). 
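As a concrete illustration of the flags documented in the new Usage section above, a training run followed by prediction might look roughly like the following; the file names, the experiment name, and the checkpoint path are placeholders rather than values taken from this change:

    yoyodyne-train --model_dir models --experiment example \
        --train train.tsv --val val.tsv

    yoyodyne-predict --model_dir models --experiment example \
        --checkpoint path/to/checkpoint.ckpt \
        --predict test.tsv --output predictions.tsv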
For all models, the user may also wish to specify: diff --git a/pyproject.toml b/pyproject.toml index 7206b1d3..56ad8ebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ exclude = ["examples*"] [project] name = "yoyodyne" -version = "0.2.3" +version = "0.2.4" description = "Small-vocabulary neural sequence-to-sequence models" readme = "README.md" requires-python = ">= 3.9" diff --git a/tests/collator_test.py b/tests/collator_test.py deleted file mode 100644 index 9e002df8..00000000 --- a/tests/collator_test.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest - -from yoyodyne import collators, dataconfig, datasets - - -@pytest.mark.parametrize( - ["arch", "has_features", "has_target", "expected_separate_features"], - [ - ("feature_invariant_transformer", True, True, False), - ("feature_invariant_transformer", True, False, False), - ("lstm", True, True, False), - ("lstm", False, True, False), - ("lstm", True, False, False), - ("lstm", False, False, False), - ("pointer_generator_lstm", True, True, True), - ("pointer_generator_lstm", False, True, False), - ("pointer_generator_lstm", True, False, True), - ("pointer_generator_lstm", False, False, False), - ("transducer", True, True, True), - ("transducer", False, True, False), - ("transducer", True, False, True), - ("transducer", False, False, False), - ("transformer", True, True, False), - ("transformer", False, True, False), - ("transformer", True, False, False), - ("transformer", False, False, False), - ], -) -def test_get_collator( - make_trivial_tsv_file, - arch, - has_features, - has_target, - expected_separate_features, -): - filename = make_trivial_tsv_file - config = dataconfig.DataConfig( - features_col=3 if has_features else 0, - target_col=2 if has_target else 0, - ) - dataset = datasets.get_dataset(filename, config) - collator = collators.Collator( - dataset, - arch, - ) - assert collator.has_target == has_target - assert collator.separate_features == expected_separate_features diff --git a/tests/dataset_test.py b/tests/dataset_test.py deleted file mode 100644 index 21421e36..00000000 --- a/tests/dataset_test.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest - -from yoyodyne import dataconfig, datasets - - -@pytest.mark.parametrize( - "features_col, expected_cls", - [(0, datasets.DatasetNoFeatures), (3, datasets.DatasetFeatures)], -) -def test_get_dataset(make_trivial_tsv_file, features_col, expected_cls): - filename = make_trivial_tsv_file - config = dataconfig.DataConfig(features_col=features_col) - dataset = datasets.get_dataset(filename, config) - assert type(dataset) is expected_cls diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py new file mode 100644 index 00000000..7553e247 --- /dev/null +++ b/yoyodyne/data/__init__.py @@ -0,0 +1,89 @@ +"""Data classes.""" + +import argparse + +from .. import defaults +from .datamodules import DataModule # noqa: F401 +from .batches import PaddedBatch, PaddedTensor # noqa: F401 +from .indexes import Index # noqa: F401 + + +def add_argparse_args(parser: argparse.ArgumentParser) -> None: + """Adds data options to the argument parser. + Args: + parser (argparse.ArgumentParser). + """ + parser.add_argument( + "--source_col", + type=int, + default=defaults.SOURCE_COL, + help="1-based index for source column. Default: %(default)s.", + ) + parser.add_argument( + "--target_col", + type=int, + default=defaults.TARGET_COL, + help="1-based index for target column. 
Default: %(default)s.", + ) + parser.add_argument( + "--features_col", + type=int, + default=defaults.FEATURES_COL, + help="1-based index for features column; " + "0 indicates the model will not use features. " + "Default: %(default)s.", + ) + parser.add_argument( + "--source_sep", + type=str, + default=defaults.SOURCE_SEP, + help="String used to split source string into symbols; " + "an empty string indicates that each Unicode codepoint " + "is its own symbol. Default: %(default)r.", + ) + parser.add_argument( + "--target_sep", + type=str, + default=defaults.TARGET_SEP, + help="String used to split target string into symbols; " + "an empty string indicates that each Unicode codepoint " + "is its own symbol. Default: %(default)r.", + ) + parser.add_argument( + "--features_sep", + type=str, + default=defaults.FEATURES_SEP, + help="String used to split features string into symbols; " + "an empty string indicates that each Unicode codepoint " + "is its own symbol. Default: %(default)r.", + ) + parser.add_argument( + "--tied_vocabulary", + action="store_true", + default=defaults.TIED_VOCABULARY, + help="Share source and target embeddings. Default: %(default)s.", + ) + parser.add_argument( + "--no_tied_vocabulary", + action="store_false", + dest="tied_vocabulary", + default=True, + ) + parser.add_argument( + "--batch_size", + type=int, + default=defaults.BATCH_SIZE, + help="Batch size. Default: %(default)s.", + ) + parser.add_argument( + "--max_source_length", + type=int, + default=defaults.MAX_SOURCE_LENGTH, + help="Maximum source string length. Default: %(default)s.", + ) + parser.add_argument( + "--max_target_length", + type=int, + default=defaults.MAX_TARGET_LENGTH, + help="Maximum target string length. Default: %(default)s.", + ) diff --git a/yoyodyne/batches.py b/yoyodyne/data/batches.py similarity index 100% rename from yoyodyne/batches.py rename to yoyodyne/data/batches.py diff --git a/yoyodyne/collators.py b/yoyodyne/data/collators.py similarity index 82% rename from yoyodyne/collators.py rename to yoyodyne/data/collators.py index ddc0a712..217b88f7 100644 --- a/yoyodyne/collators.py +++ b/yoyodyne/data/collators.py @@ -1,57 +1,30 @@ """Collators and related utilities.""" import argparse +import dataclasses from typing import List import torch -from . import batches, datasets, defaults, util +from .. import defaults, util +from . import batches, datasets class LengthError(Exception): pass +@dataclasses.dataclass class Collator: """Pads data.""" pad_idx: int - features_offset: int has_features: bool has_target: bool - max_source_length: int - max_target_length: int separate_features: bool - - def __init__( - self, - dataset: datasets.BaseDataset, - arch: str, - max_source_length: int = defaults.MAX_SOURCE_LENGTH, - max_target_length: int = defaults.MAX_TARGET_LENGTH, - ): - """Initializes the collator. - - Args: - dataset (dataset.BaseDataset). - arch (str). - max_source_length (int). - max_target_length (int). 
- """ - self.index = dataset.index - self.pad_idx = self.index.pad_idx - self.config = dataset.config - self.has_features = self.config.has_features - self.has_target = self.config.has_target - self.max_source_length = max_source_length - self.max_target_length = max_target_length - self.features_offset = ( - dataset.index.source_vocab_size if self.has_features else 0 - ) - self.separate_features = dataset.config.has_features and arch in [ - "pointer_generator_lstm", - "transducer", - ] + features_offset: int + max_source_length: int = defaults.MAX_SOURCE_LENGTH + max_target_length: int = defaults.MAX_TARGET_LENGTH def _source_length_error(self, padded_length: int) -> None: """Callback function to raise the error when the padded length of the diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py new file mode 100644 index 00000000..22ef289a --- /dev/null +++ b/yoyodyne/data/datamodules.py @@ -0,0 +1,203 @@ +"""Data modules.""" + +from typing import Iterator, Optional, Set + +import pytorch_lightning as pl +from torch.utils import data + +from .. import defaults, util +from . import collators, datasets, indexes, tsv + + +class DataModule(pl.LightningDataModule): + """Parses, indexes, collates and loads data.""" + + parser: tsv.TsvParser + index: indexes.Index + batch_size: int + collator: collators.Collator + + def __init__( + self, + # Paths. + *, + train: Optional[str] = None, + val: Optional[str] = None, + predict: Optional[str] = None, + test: Optional[str] = None, + index_path: Optional[str] = None, + # TSV parsing arguments. + source_col: int = defaults.SOURCE_COL, + features_col: int = defaults.FEATURES_COL, + target_col: int = defaults.TARGET_COL, + # String parsing arguments. + source_sep: str = defaults.SOURCE_SEP, + features_sep: str = defaults.FEATURES_SEP, + target_sep: str = defaults.TARGET_SEP, + # Vocabulary options. + tied_vocabulary: bool = defaults.TIED_VOCABULARY, + # Collator options. + batch_size=defaults.BATCH_SIZE, + separate_features: bool = False, + max_source_length: int = defaults.MAX_SOURCE_LENGTH, + max_target_length: int = defaults.MAX_TARGET_LENGTH, + # Indexing. + index: Optional[indexes.Index] = None, + ): + super().__init__() + self.parser = tsv.TsvParser( + source_col=source_col, + features_col=features_col, + target_col=target_col, + source_sep=source_sep, + features_sep=features_sep, + target_sep=target_sep, + ) + self.train = train + self.val = val + self.predict = predict + self.test = test + self.batch_size = batch_size + self.separate_features = separate_features + self.index = ( + index if index is not None else self._make_index(tied_vocabulary) + ) + self.collator = collators.Collator( + pad_idx=self.index.pad_idx, + has_features=self.index.has_features, + has_target=self.index.has_target, + separate_features=separate_features, + features_offset=self.index.source_vocab_size + if self.index.has_features + else 0, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + + def _make_index(self, tied_vocabulary: bool) -> indexes.Index: + # Computes index. 
+ source_vocabulary: Set[str] = set() + features_vocabulary: Set[str] = set() + target_vocabulary: Set[str] = set() + for path in self.paths: + if self.parser.has_features: + if self.parser.has_target: + for source, features, target in self.parser.samples(path): + source_vocabulary.update(source) + features_vocabulary.update(features) + target_vocabulary.update(target) + else: + for source, features in self.parser.samples(path): + source_vocabulary.update(source) + features_vocabulary.update(features) + elif self.parser.has_target: + for source, target in self.parser.samples(path): + source_vocabulary.update(source) + target_vocabulary.update(target) + else: + for source in self.parser.samples(path): + source_vocabulary.update(source) + if self.parser.has_target and tied_vocabulary: + source_vocabulary.update(target_vocabulary) + target_vocabulary.update(source_vocabulary) + return indexes.Index( + source_vocabulary=sorted(source_vocabulary), + features_vocabulary=sorted(features_vocabulary) + if features_vocabulary + else None, + target_vocabulary=sorted(target_vocabulary) + if target_vocabulary + else None, + ) + + # Helpers. + + @property + def paths(self) -> Iterator[str]: + if self.train is not None: + yield self.train + if self.val is not None: + yield self.val + if self.predict is not None: + yield self.predict + if self.test is not None: + yield self.test + + def log_vocabularies(self) -> None: + """Logs this module's vocabularies.""" + util.log_info(f"Source vocabulary: {self.index.source_map.pprint()}") + if self.index.has_features: + util.log_info( + f"Features vocabulary: {self.index.features_map.pprint()}" + ) + if self.index.has_target: + util.log_info( + f"Target vocabulary: {self.index.target_map.pprint()}" + ) + + def write_index(self, model_dir: str, experiment: str) -> None: + """Writes the index.""" + self.index.write(model_dir, experiment) + + @property + def has_features(self) -> int: + return self.index.has_features + + @property + def has_target(self) -> int: + return self.index.has_target + + @property + def source_vocab_size(self) -> int: + if self.separate_features: + return self.index.source_vocab_size + else: + return ( + self.index.source_vocab_size + self.index.features_vocab_size + ) + + def _dataset(self, path: str) -> datasets.Dataset: + return datasets.Dataset( + list(self.parser.samples(path)), + self.index, + self.parser, + ) + + # Required API. + + def train_dataloader(self) -> data.DataLoader: + assert self.train is not None, "no train path" + return data.DataLoader( + self._dataset(self.train), + collate_fn=self.collator, + batch_size=self.batch_size, + shuffle=True, + num_workers=1, + ) + + def val_dataloader(self) -> data.DataLoader: + assert self.val is not None, "no val path" + return data.DataLoader( + self._dataset(self.val), + collate_fn=self.collator, + batch_size=2 * self.batch_size, # Because no gradients. + num_workers=1, + ) + + def predict_dataloader(self) -> data.DataLoader: + assert self.predict is not None, "no predict path" + return data.DataLoader( + self._dataset(self.predict), + collate_fn=self.collator, + batch_size=2 * self.batch_size, # Because no gradients. + num_workers=1, + ) + + def test_dataloader(self) -> data.DataLoader: + assert self.test is not None, "no test path" + return data.DataLoader( + self._dataset(self.test), + collate_fn=self.collator, + batch_size=2 * self.batch_size, # Because no gradients. 
+ num_workers=1, + ) diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py new file mode 100644 index 00000000..267379dc --- /dev/null +++ b/yoyodyne/data/datasets.py @@ -0,0 +1,230 @@ +"""Datasets and related utilities. + +Anything which has a tensor member should inherit from nn.Module, run the +superclass constructor, and register the tensor as a buffer. This enables the +Trainer to move them to the appropriate device.""" + +import dataclasses + +from typing import Iterator, List, Optional + +import torch +from torch import nn +from torch.utils import data + +from .. import special + +from . import indexes, tsv + + +class Item(nn.Module): + """Source tensor, with optional features and target tensors. + + This represents a single item or observation.""" + + source: torch.Tensor + features: Optional[torch.Tensor] + target: Optional[torch.Tensor] + + def __init__(self, source, features=None, target=None): + """Initializes the item. + + Args: + source (torch.Tensor). + features (torch.Tensor, optional). + target (torch.Tensor, optional). + """ + super().__init__() + self.register_buffer("source", source) + self.register_buffer("features", features) + self.register_buffer("target", target) + + @property + def has_features(self): + return self.features is not None + + @property + def has_target(self): + return self.target is not None + + +@dataclasses.dataclass +class Dataset(data.Dataset): + """Datatset class.""" + + samples: List[List[str]] + index: indexes.Index # Usually copied from the DataModule. + parser: tsv.TsvParser # Ditto. + + @property + def has_features(self) -> bool: + return self.index.has_features + + @property + def has_target(self) -> bool: + return self.index.has_target + + def _encode( + self, + symbols: List[str], + symbol_map: indexes.SymbolMap, + ) -> torch.Tensor: + """Encodes a sequence as a tensor of indices with string boundary IDs. + + Args: + symbols (List[str]): symbols to be encoded. + symbol_map (indexes.SymbolMap): symbol map to encode with. + + Returns: + torch.Tensor: the encoded tensor. + """ + return torch.tensor( + [ + symbol_map.index(symbol, self.index.unk_idx) + for symbol in symbols + ], + dtype=torch.long, + ) + + def encode_source(self, symbols: List[str]) -> torch.Tensor: + """Encodes a source string, padding with start and end tags. + + Args: + symbols (List[str]). + + Returns: + torch.Tensor. + """ + wrapped = [special.START] + wrapped.extend(symbols) + wrapped.append(special.END) + return self._encode(wrapped, self.index.source_map) + + def encode_features(self, symbols: List[str]) -> torch.Tensor: + """Encodes a features string. + + Args: + symbols (List[str]). + + Returns: + torch.Tensor. + """ + return self._encode(symbols, self.index.features_map) + + def encode_target(self, symbols: List[str]) -> torch.Tensor: + """Encodes a features string, padding with end tags. + + Args: + symbols (List[str]). + + Returns: + torch.Tensor. + """ + wrapped = symbols + wrapped.append(special.END) + return self._encode(wrapped, self.index.target_map) + + # Decoding. + + def _decode( + self, + indices: torch.Tensor, + symbol_map: indexes.SymbolMap, + ) -> Iterator[List[str]]: + """Decodes the tensor of indices into lists of symbols. + + Args: + indices (torch.Tensor): 2d tensor of indices. + symbol_map (indexes.SymbolMap). + + Yields: + List[str]: Decoded symbols. 
+ """ + for idx in indices.cpu().numpy(): + yield [ + symbol_map.symbol(c) + for c in idx + if c not in self.index.special_idx + ] + + def decode_source( + self, + indices: torch.Tensor, + ) -> Iterator[str]: + """Decodes a source tensor. + + Args: + indices (torch.Tensor): 2d tensor of indices. + + Yields: + str: Decoded source strings. + """ + for symbols in self._decode(indices, self.index.source_map): + yield self.parser.source_string(symbols) + + def decode_features( + self, + indices: torch.Tensor, + ) -> Iterator[str]: + """Decodes a features tensor. + + Args: + indices (torch.Tensor): 2d tensor of indices. + + Yields: + str: Decoded features strings. + """ + for symbols in self._decode(indices, self.index.target_map): + yield self.parser.feature_string(symbols) + + def decode_target( + self, + indices: torch.Tensor, + ) -> Iterator[str]: + """Decodes a target tensor. + + Args: + indices (torch.Tensor): 2d tensor of indices. + + Yields: + str: Decoded target strings. + """ + for symbols in self._decode(indices, self.index.target_map): + yield self.parser.target_string(symbols) + + # Required API. + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int) -> Item: + """Retrieves item by index. + + Args: + idx (int). + + Returns: + Item. + """ + if self.index.has_features: + if self.index.has_target: + source, features, target = self.samples[idx] + return Item( + source=self.encode_source(source), + features=self.encode_features(features), + target=self.encode_target(target), + ) + else: + source, features = self.samples[idx] + return Item( + source=self.encode_source(source), + features=self.encode_features(features), + ) + elif self.index.has_target: + source, target = self.samples[idx] + return Item( + source=self.encode_source(source), + target=self.encode_target(target), + ) + else: + return Item(source=self.encode_source(source)) diff --git a/yoyodyne/indexes.py b/yoyodyne/data/indexes.py similarity index 88% rename from yoyodyne/indexes.py rename to yoyodyne/data/indexes.py index 22ccedb5..d7c2d6d5 100644 --- a/yoyodyne/indexes.py +++ b/yoyodyne/data/indexes.py @@ -4,7 +4,7 @@ import pickle from typing import Dict, List, Optional, Set -from . import special +from .. import special class SymbolMap: @@ -84,39 +84,46 @@ def __init__( # Serialization support. - @staticmethod - def index_path(model_dir: str, experiment: str) -> str: - """Computes the index path. + @classmethod + def read(cls, model_dir: str, experiment: str) -> "Index": + """Loads index. Args: model_dir (str). experiment (str). Returns: - str. - """ - return f"{model_dir}/{experiment}/index.pkl" - - @classmethod - def read(cls, path: str): - """Loads symbol mappings. - - Args: - path (str): input path. + Index. """ index = cls.__new__(cls) + path = index.index_path(model_dir, experiment) with open(path, "rb") as source: dictionary = pickle.load(source) for key, value in dictionary.items(): setattr(index, key, value) return index - def write(self, path: str) -> None: - """Saves index. + @staticmethod + def index_path(model_dir: str, experiment: str) -> str: + """Computes path for the index file. + + Args: + model_dir (str). + experiment (str). + + Returns: + str. + """ + return f"{model_dir}/{experiment}/index.pkl" + + def write(self, model_dir: str, experiment: str) -> None: + """Writes index. Args: - path (str): output path. + model_dir (str). + experiment (str). 
""" + path = self.index_path(model_dir, experiment) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as sink: pickle.dump(vars(self), sink) diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py new file mode 100644 index 00000000..5087ffda --- /dev/null +++ b/yoyodyne/data/tsv.py @@ -0,0 +1,150 @@ +"""TSV parsing. + +The TsvParser yields data from TSV files using 1-based indexing and custom +separators. +""" + + +import csv +import dataclasses +from typing import Iterator, List, Tuple, Union + +from .. import defaults + + +class Error(Exception): + """Module-specific exception.""" + + pass + + +@dataclasses.dataclass +class TsvParser: + """Streams data from a TSV file. + + Args: + source_col (int, optional): 1-indexed column in TSV containing + source strings. + features_col (int, optional): 1-indexed column in TSV containing + features strings. + target_col (int, optional): 1-indexed column in TSV containing + target strings. + source_sep (str, optional): string used to split source string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + features_sep (str, optional): string used to split features string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + target_sep (str, optional): string used to split target string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + """ + + source_col: int = defaults.SOURCE_COL + features_col: int = defaults.FEATURES_COL + target_col: int = defaults.TARGET_COL + source_sep: str = defaults.SOURCE_SEP + features_sep: str = defaults.FEATURES_SEP + target_sep: str = defaults.TARGET_SEP + + def __post_init__(self) -> None: + # This is automatically called after initialization. + if self.source_col < 1: + raise Error(f"Out of range source column: {self.source_col}") + if self.features_col < 0: + raise Error(f"Out of range features column: {self.features_col}") + if self.features_col < 0: + raise Error(f"Out of range features column: {self.features_col}") + if self.target_col < 0: + raise Error(f"Out of range target column: {self.target_col}") + + @staticmethod + def _tsv_reader(path: str) -> Iterator[str]: + with open(path, "r") as tsv: + yield from csv.reader(tsv, delimiter="\t") + + @staticmethod + def _get_string(row: List[str], col: int) -> str: + """Returns a string from a row by index. + + Args: + row (List[str]): the split row. + col (int): the column index. + Returns: + str: symbol from that string. + """ + return row[col - 1] # -1 because we're using one-based indexing. + + @property + def has_features(self) -> bool: + return self.features_col != 0 + + @property + def has_target(self) -> bool: + return self.target_col != 0 + + def samples( + self, path: str + ) -> Iterator[ + Union[ + List[str], + Tuple[List[str], List[str]], + Tuple[List[str], List[str], List[str]], + ] + ]: + """Yields source, and features and/or target if available.""" + for row in self._tsv_reader(path): + source = self.source_symbols( + self._get_string(row, self.source_col) + ) + if self.has_features: + features = self.features_symbols( + self._get_string(row, self.features_col) + ) + if self.has_target: + target = self.target_symbols( + self._get_string(row, self.target_col) + ) + yield source, features, target + else: + yield source, features + elif self.has_target: + target = self.target_symbols( + self._get_string(row, self.target_col) + ) + yield source, target + else: + yield source + + # String parsing methods. 
+ + @staticmethod + def _get_symbols(string: str, sep: str) -> List[str]: + return list(string) if not sep else string.split(sep) + + def source_symbols(self, string: str) -> List[str]: + return self._get_symbols(string, self.source_sep) + + def features_symbols(self, string: str) -> List[str]: + # We deliberately obfuscate these to avoid overlap with source. + return [ + f"[{symbol}]" + for symbol in self._get_symbols(string, self.features_sep) + ] + + def target_symbols(self, string: str) -> List[str]: + return self._get_symbols(string, self.target_sep) + + # Deserialization methods. + + def source_string(self, symbols: List[str]) -> str: + return self.source_sep.join(symbols) + + def features_string(self, symbols: List[str]) -> str: + return self.features_sep.join( + # This indexing strips off the obfuscation. + [symbol[1:-1] for symbol in symbols], + ) + + def target_string(self, symbols: List[str]) -> str: + return self.target_sep.join(symbols) diff --git a/yoyodyne/dataconfig.py b/yoyodyne/dataconfig.py deleted file mode 100644 index 728f9e0b..00000000 --- a/yoyodyne/dataconfig.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Dataset config class.""" - -import argparse -import csv -import dataclasses -import inspect -from typing import Iterator, List, Tuple - -from . import defaults, util - - -class Error(Exception): - """Module-specific exception.""" - - pass - - -@dataclasses.dataclass -class DataConfig: - """Configuration specifications for a dataset. - - Args: - source_col (int, optional): 1-indexed column in TSV containing - source strings. - target_col (int, optional): 1-indexed column in TSV containing - target strings. - features_col (int, optional): 1-indexed column in TSV containing - features strings. - source_sep (str, optional): separator character between symbol in - source string. "" treats each character in source as a symbol. - target_sep (str, optional): separator character between symbol in - target string. "" treats each character in target as a symbol. - features_sep (str, optional): separator character between symbol in - features string. "" treats each character in features as a symbol. - tied_vocabulary (bool, optional): whether the source and target - should share a vocabulary. - """ - - source_col: int = defaults.SOURCE_COL - target_col: int = defaults.TARGET_COL - features_col: int = defaults.FEATURES_COL - source_sep: str = defaults.SOURCE_SEP - target_sep: str = defaults.TARGET_SEP - features_sep: str = defaults.FEATURES_SEP - tied_vocabulary: bool = defaults.TIED_VOCABULARY - - def __post_init__(self) -> None: - # This is automatically called after initialization. - if self.source_col < 1: - raise Error(f"Invalid source column: {self.source_col}") - if self.target_col < 0: - raise Error(f"Invalid target column: {self.target_col}") - if self.target_col == 0: - util.log_info("Ignoring targets in input") - if self.features_col < 0: - raise Error(f"Invalid features column: {self.features_col}") - if self.features_col != 0: - util.log_info("Including features") - - @classmethod - def from_argparse_args(cls, args, **kwargs): - """Creates an instance from CLI arguments.""" - params = vars(args) - valid_kwargs = inspect.signature(cls.__init__).parameters - dataconfig_kwargs = { - name: params[name] for name in valid_kwargs if name in params - } - dataconfig_kwargs.update(**kwargs) - return cls(**dataconfig_kwargs) - - @staticmethod - def _get_cell(row: List[str], col: int, sep: str) -> List[str]: - """Returns the split cell of a row. 
- - Args: - row (List[str]): the split row. - col (int): the column index - sep (str): the string to split the column on; if the empty string, - the column is split into characters instead. - - Returns: - List[str]: symbol from that cell. - """ - cell = row[col - 1] # -1 because we're using one-based indexing. - return list(cell) if not sep else cell.split(sep) - - # Source is always present. - - @property - def has_target(self) -> bool: - return self.target_col != 0 - - @property - def has_features(self) -> bool: - return self.features_col != 0 - - def source_samples(self, filename: str) -> Iterator[List[str]]: - """Yields source.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - yield self._get_cell(row, self.source_col, self.source_sep) - - def source_target_samples( - self, filename: str - ) -> Iterator[Tuple[List[str], List[str]]]: - """Yields source and target.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - source = self._get_cell(row, self.source_col, self.source_sep) - target = self._get_cell(row, self.target_col, self.target_sep) - yield source, target - - def source_features_target_samples( - self, filename: str - ) -> Iterator[Tuple[List[str], List[str], List[str]]]: - """Yields source, features, and target.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - source = self._get_cell(row, self.source_col, self.source_sep) - # Avoids overlap with source. - features = [ - f"[{feature}]" - for feature in self._get_cell( - row, self.features_col, self.features_sep - ) - ] - target = self._get_cell(row, self.target_col, self.target_sep) - yield source, features, target - - def source_features_samples( - self, filename: str - ) -> Iterator[Tuple[List[str], List[str]]]: - """Yields source, and features.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - source = self._get_cell(row, self.source_col, self.source_sep) - # Avoids overlap with source. - features = [ - f"[{feature}]" - for feature in self._get_cell( - row, self.features_col, self.features_sep - ) - ] - yield source, features - - def samples(self, filename: str) -> Iterator[Tuple[List[str], ...]]: - """Picks the right one for this config.""" - if self.has_features: - return ( - self.source_features_target_samples(filename) - if self.has_target - else self.source_features_samples(filename) - ) - else: - return ( - self.source_target_samples(filename) - if self.has_target - else self.source_samples(filename) - ) - - @staticmethod - def add_argparse_args(parser: argparse.ArgumentParser) -> None: - """Adds data configuration options to the argument parser. - - Args: - parser (argparse.ArgumentParser). - """ - parser.add_argument( - "--source_col", - type=int, - default=defaults.SOURCE_COL, - help="1-based index for source column. Default: %(default)s.", - ) - parser.add_argument( - "--target_col", - type=int, - default=defaults.TARGET_COL, - help="1-based index for target column. Default: %(default)s.", - ) - parser.add_argument( - "--features_col", - type=int, - default=defaults.FEATURES_COL, - help="1-based index for features column; " - "0 indicates the model will not use features. 
" - "Default: %(default)s.", - ) - parser.add_argument( - "--source_sep", - type=str, - default=defaults.SOURCE_SEP, - help="String used to split source string into symbols; " - "an empty string indicates that each Unicode codepoint " - "is its own symbol. Default: %(default)r.", - ) - parser.add_argument( - "--target_sep", - type=str, - default=defaults.TARGET_SEP, - help="String used to split target string into symbols; " - "an empty string indicates that each Unicode codepoint " - "is its own symbol. Default: %(default)r.", - ) - parser.add_argument( - "--features_sep", - type=str, - default=defaults.FEATURES_SEP, - help="String used to split features string into symbols; " - "an empty string indicates that each Unicode codepoint " - "is its own symbol. Default: %(default)r.", - ) - parser.add_argument( - "--tied_vocabulary", - action="store_true", - default=defaults.TIED_VOCABULARY, - help="Share source and target embeddings. Default: %(default)s.", - ) - parser.add_argument( - "--no_tied_vocabulary", - action="store_false", - dest="tied_vocabulary", - default=True, - ) diff --git a/yoyodyne/datasets.py b/yoyodyne/datasets.py deleted file mode 100644 index 6966529b..00000000 --- a/yoyodyne/datasets.py +++ /dev/null @@ -1,375 +0,0 @@ -"""Datasets and related utilities. - -Anything which has a tensor member should inherit from nn.Module, run the -superclass constructor, and register the tensor as a buffer. This enables the -Trainer to move them to the appropriate device.""" - -from typing import List, Optional, Set, Union - -import torch -from torch import nn -from torch.utils import data - -from . import dataconfig, indexes, special - - -class Item(nn.Module): - """Source tensor, with optional features and target tensors. - - This represents a single item or observation.""" - - source: torch.Tensor - features: Optional[torch.Tensor] - target: Optional[torch.Tensor] - - def __init__(self, source, features=None, target=None): - """Initializes the item. - - Args: - source (torch.Tensor). - features (torch.Tensor, optional). - target (torch.Tensor, optional). - """ - super().__init__() - self.register_buffer("source", source) - self.register_buffer("features", features) - self.register_buffer("target", target) - - @property - def has_features(self): - return self.features is not None - - @property - def has_target(self): - return self.target is not None - - -class BaseDataset(data.Dataset): - """Base datatset class.""" - - def __init__(self): - super().__init__() - - -class DatasetNoFeatures(BaseDataset): - """Dataset object without feature column.""" - - filename: str - config: dataconfig.DataConfig - samples: List[List[str]] - index: indexes.Index - - def __init__( - self, - filename, - config, - index: Optional[indexes.Index] = None, - ): - """Initializes the dataset. - - Args: - filename (str): input filename. - config (dataconfig.DataConfig): dataset configuration. - other (indexes.Index, optional): if provided, - use this index to avoid recomputing it. 
- """ - super().__init__() - self.config = config - self.samples = list(self.config.samples(filename)) - self.index = index if index is not None else self._make_index() - - def _make_index(self) -> indexes.Index: - """Generates index.""" - source_vocabulary: Set[str] = set() - target_vocabulary: Set[str] = set() - if self.config.has_target: - for source, target in self.samples: - source_vocabulary.update(source) - target_vocabulary.update(target) - if self.config.tied_vocabulary: - source_vocabulary.update(target_vocabulary) - target_vocabulary.update(source_vocabulary) - else: - for source in self.samples: - source_vocabulary.update(source) - return indexes.Index( - source_vocabulary=sorted(source_vocabulary), - target_vocabulary=sorted(target_vocabulary), - ) - - def encode( - self, - symbol_map: indexes.SymbolMap, - word: List[str], - add_start_tag: bool = True, - add_end_tag: bool = True, - ) -> torch.Tensor: - """Encodes a sequence as a tensor of indices with word boundary IDs. - - Args: - symbol_map (indexes.SymbolMap). - word (List[str]): word to be encoded. - add_start_tag (bool, optional): whether the sequence should be - prepended with a start tag. - add_end_tag (bool, optional): whether the sequence should be - prepended with a end tag. - - Returns: - torch.Tensor: the encoded tensor. - """ - sequence = [] - if add_start_tag: - sequence.append(special.START) - sequence.extend(word) - if add_end_tag: - sequence.append(special.END) - return torch.tensor( - [ - symbol_map.index(symbol, self.index.unk_idx) - for symbol in sequence - ], - dtype=torch.long, - ) - - def _decode( - self, - symbol_map: indexes.SymbolMap, - indices: torch.Tensor, - symbols: bool, - special: bool, - ) -> List[List[str]]: - """Decodes the tensor of indices into symbols. - - Args: - symbol_map (indexes.SymbolMap). - indices (torch.Tensor): 2d tensor of indices. - symbols (bool): whether to include the regular symbols when - decoding the string. - special (bool): whether to include the special symbols when - decoding the string. - - Returns: - List[List[str]]: decoded symbols. - """ - - def include(c: int) -> bool: - """Whether to include the symbol when decoding. - - Args: - c (int): a single symbol index. - - Returns: - bool: whether to include the symbol. - """ - include = False - is_special_char = c in self.index.special_idx - if special: - include |= is_special_char - if symbols: - # Symbols will be anything that is not SPECIAL. - include |= not is_special_char - return include - - decoded = [] - for index in indices.cpu().numpy(): - decoded.append([symbol_map.symbol(c) for c in index if include(c)]) - return decoded - - def decode_source( - self, - indices: torch.Tensor, - symbols: bool = True, - special: bool = True, - ) -> List[List[str]]: - """Given a tensor of source indices, returns lists of symbols. - - Args: - indices (torch.Tensor): 2d tensor of indices. - symbols (bool, optional): whether to include the regular symbols - vocabulary when decoding the string. - special (bool, optional): whether to include the special symbols - when decoding the string. - - Returns: - List[List[str]]: decoded symbols. - """ - return self._decode( - self.index.source_map, - indices, - symbols=symbols, - special=special, - ) - - def decode_target( - self, - indices: torch.Tensor, - symbols: bool = True, - special: bool = True, - ) -> List[List[str]]: - """Given a tensor of target indices, returns lists of symbols. - - Args: - indices (torch.Tensor): 2d tensor of indices. 
- special (bool, optional): whether to include the regular symbol - vocabulary when decoding the string. - special (bool, optional): whether to include the special symbols - when decoding the string. - - Returns: - List[List[str]]: decoded symbols. - """ - return self._decode( - self.index.target_map, - indices, - symbols=symbols, - special=special, - ) - - @property - def has_features(self) -> bool: - return self.index.has_features - - @staticmethod - def read_index(path: str) -> indexes.Index: - """Helper for loading index. - - Args: - path (str). - - Returns: - indexes.IndexNoFeatures. - """ - return indexes.Index.read(path) - - def __getitem__(self, idx: int) -> Item: - """Retrieves item by index. - - Args: - idx (int). - - Returns: - Item. - """ - if self.config.has_target: - source, target = self.samples[idx] - else: - source = self.samples[idx] - source_encoded = self.encode(self.index.source_map, source) - if self.config.has_target: - target_encoded = self.encode( - self.index.target_map, target, add_start_tag=False - ) - return Item(source_encoded, target=target_encoded) - else: - return Item(source_encoded) - - def __len__(self) -> int: - return len(self.samples) - - -class DatasetFeatures(DatasetNoFeatures): - """Dataset object with feature column.""" - - def _make_index(self) -> indexes.Index: - """Generates index. - - Same as in superclass, but also handles features. - """ - source_vocabulary: Set[str] = set() - features_vocabulary: Set[str] = set() - target_vocabulary: Set[str] = set() - if self.config.has_target: - for source, features, target in self.samples: - source_vocabulary.update(source) - features_vocabulary.update(features) - target_vocabulary.update(target) - if self.config.tied_vocabulary: - source_vocabulary.update(target_vocabulary) - target_vocabulary.update(source_vocabulary) - else: - for source, features in self.samples: - source_vocabulary.update(source) - features_vocabulary.update(features) - return indexes.Index( - source_vocabulary=sorted(source_vocabulary), - features_vocabulary=sorted(features_vocabulary), - target_vocabulary=sorted(target_vocabulary), - ) - - def __getitem__(self, idx: int) -> Item: - """Retrieves item by index. - - Args: - idx (int). - - Returns: - Item. - """ - if self.config.has_target: - source, features, target = self.samples[idx] - else: - source, features = self.samples[idx] - source_encoded = self.encode(self.index.source_map, source) - features_encoded = self.encode( - self.index.features_map, - features, - add_start_tag=False, - add_end_tag=False, - ) - if self.config.has_target: - return Item( - source_encoded, - target=self.encode( - self.index.target_map, target, add_start_tag=False - ), - features=features_encoded, - ) - else: - return Item(source_encoded, features=features_encoded) - - def decode_features( - self, - indices: torch.Tensor, - symbols: bool = True, - special: bool = True, - ) -> List[List[str]]: - """Given a tensor of feature indices, returns lists of symbols. - - Args: - indices (torch.Tensor): 2d tensor of indices. - symbols (bool, optional): whether to include the regular symbols - when decoding the string. - special (bool, optional): whether to include the special symbols - when decoding the string. - - Returns: - List[List[str]]: decoded symbols. 
- """ - return self._decode( - self.index.features_map, - indices, - symbols=symbols, - special=special, - ) - - -def get_dataset( - filename: str, - config: dataconfig.DataConfig, - index: Union[indexes.Index, str, None] = None, -) -> data.Dataset: - """Dataset factory. - - Args: - filename (str): input filename. - config (dataconfig.DataConfig): dataset configuration. - index (Union[index.Index, str], optional): input index file, - or path to index.pkl file. - - Returns: - data.Dataset: the dataset. - """ - cls = DatasetFeatures if config.has_features else DatasetNoFeatures - if isinstance(index, str): - index = indexes.Index.read(index) - return cls(filename, config, index) diff --git a/yoyodyne/models/__init__.py b/yoyodyne/models/__init__.py index 9a645398..201dfdf1 100644 --- a/yoyodyne/models/__init__.py +++ b/yoyodyne/models/__init__.py @@ -49,7 +49,7 @@ def get_model_cls_from_argparse_args( Returns: BaseEncoderDecoder. """ - return get_model_cls(args.arch, args.features_col != 0) + return get_model_cls(args.arch) def add_argparse_args(parser: argparse.ArgumentParser) -> None: diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index 72774e0b..2c7db7ea 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -10,7 +10,7 @@ import torch from torch import nn, optim -from .. import batches, defaults, evaluators, schedulers, util +from .. import data, defaults, evaluators, schedulers, util from . import modules @@ -235,7 +235,7 @@ def has_features_encoder(self): def training_step( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, batch_idx: int, ) -> torch.Tensor: """Runs one step of training. @@ -243,7 +243,7 @@ def training_step( This is called by the PL Trainer. Args: - batch (batches.PaddedBatch) + batch (data.PaddedBatch) batch_idx (int). Returns: @@ -266,7 +266,7 @@ def training_step( def validation_step( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, batch_idx: int, ) -> Dict: """Runs one validation step. @@ -274,7 +274,7 @@ def validation_step( This is called by the PL Trainer. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). batch_idx (int). Returns: @@ -316,7 +316,7 @@ def validation_epoch_end(self, validation_step_outputs: Dict) -> Dict: def predict_step( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, batch_idx: int, ) -> torch.Tensor: """Runs one predict step. @@ -324,7 +324,7 @@ def predict_step( This is called by the PL Trainer. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). batch_idx (int). Returns: diff --git a/yoyodyne/models/lstm.py b/yoyodyne/models/lstm.py index 846271a6..e656e106 100644 --- a/yoyodyne/models/lstm.py +++ b/yoyodyne/models/lstm.py @@ -7,7 +7,7 @@ import torch from torch import nn -from .. import batches, defaults +from .. import data, defaults from . import base, modules @@ -139,10 +139,10 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all batches predicted an EOS symbol. - # If we have a target (and are thus computing loss), - # we only break when we have decoded at least the the - # same number of steps as the target length. + # Breaks when all sequences have predicted an EOS symbol. If we + # have a target (and are thus computing loss), we only break + # when we have decoded at least the the same number of steps as + # the target length. 
if finished.all(): if target is None or decoder_input.size(-1) >= target.size( -1 @@ -279,12 +279,12 @@ def beam_decode( def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> torch.Tensor: """Runs the encoder-decoder model. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: predictions (torch.Tensor): tensor of predictions of shape diff --git a/yoyodyne/models/modules/linear.py b/yoyodyne/models/modules/linear.py index 4248289a..aeb79611 100644 --- a/yoyodyne/models/modules/linear.py +++ b/yoyodyne/models/modules/linear.py @@ -5,7 +5,7 @@ import torch from torch import nn -from ... import batches +from ... import data from . import base @@ -32,12 +32,12 @@ def init_embeddings( class LinearEncoder(LinearModule): def forward( - self, source: batches.PaddedTensor + self, source: data.PaddedTensor ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Encodes the input. Args: - source (batches.PaddedTensor): source padded tensors and mask + source (data.PaddedTensor): source padded tensors and mask for source, of shape B x seq_len x 1. Returns: diff --git a/yoyodyne/models/modules/lstm.py b/yoyodyne/models/modules/lstm.py index e982bacb..52598d6e 100644 --- a/yoyodyne/models/modules/lstm.py +++ b/yoyodyne/models/modules/lstm.py @@ -5,7 +5,7 @@ import torch from torch import nn -from ... import batches, defaults +from ... import data, defaults from . import attention, base @@ -58,12 +58,12 @@ def init_embeddings( class LSTMEncoder(LSTMModule): def forward( - self, source: batches.PaddedTensor + self, source: data.PaddedTensor ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Encodes the input. Args: - source (batches.PaddedTensor): source padded tensors and mask + source (data.PaddedTensor): source padded tensors and mask for source, of shape B x seq_len x 1. Returns: diff --git a/yoyodyne/models/modules/transformer.py b/yoyodyne/models/modules/transformer.py index d2eef484..8a20508b 100644 --- a/yoyodyne/models/modules/transformer.py +++ b/yoyodyne/models/modules/transformer.py @@ -6,7 +6,7 @@ import torch from torch import nn -from ... import batches +from ... import data from . import base @@ -146,11 +146,11 @@ def embed(self, symbols: torch.Tensor) -> torch.Tensor: class TransformerEncoder(TransformerModule): - def forward(self, source: batches.PaddedTensor) -> torch.Tensor: + def forward(self, source: data.PaddedTensor) -> torch.Tensor: """Encodes the source with the TransformerEncoder. Args: - source (batches.PaddedTensor). + source (data.PaddedTensor). Returns: torch.Tensor: sequence of encoded symbols. diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index 30e9087c..96a4fb74 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -6,7 +6,7 @@ import torch from torch import nn -from .. import batches +from .. import data from . import lstm, modules @@ -289,10 +289,10 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all batches predicted an EOS symbol. - # If we have a target (and are thus computing loss), - # we only break when we have decoded at least the the - # same number of steps as the target length. + # Breaks when all sequences have predicted an EOS symbol. If we + # have a target (and are thus computing loss), we only break + # when we have decoded at least the the same number of steps as + # the target length. 
if finished.all(): if target is None or decoder_input.size(-1) >= target.size( -1 @@ -303,12 +303,12 @@ def decode( def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> torch.Tensor: """Runs the encoder-decoder. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: torch.Tensor. diff --git a/yoyodyne/models/transducer.py b/yoyodyne/models/transducer.py index 7f35fd99..3cb7d96b 100644 --- a/yoyodyne/models/transducer.py +++ b/yoyodyne/models/transducer.py @@ -8,7 +8,7 @@ from maxwell import actions from torch import nn -from .. import batches +from .. import data from . import expert, lstm, modules @@ -71,12 +71,12 @@ def get_decoder(self) -> modules.lstm.LSTMDecoder: def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> Tuple[List[List[int]], torch.Tensor]: """Runs the encoder-decoder model. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: Tuple[List[List[int]], torch.Tensor] of encoded prediction values @@ -502,15 +502,13 @@ def _get_loss_func( # Prevents base construction of unused loss function. return None - def training_step( - self, batch: batches.PaddedBatch, batch_idx: int - ) -> Dict: + def training_step(self, batch: data.PaddedBatch, batch_idx: int) -> Dict: """Runs one step of training. This is called by the PL Trainer. Args: - batch (batches.PaddedBatch) + batch (data.PaddedBatch) batch_idx (int). Returns: @@ -527,9 +525,7 @@ def training_step( ) return loss - def validation_step( - self, batch: batches.PaddedBatch, batch_idx: int - ) -> Dict: + def validation_step(self, batch: data.PaddedBatch, batch_idx: int) -> Dict: predictions, loss = self(batch) # Evaluation requires prediction as a tensor. predictions = self.convert_prediction(predictions) diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index 354f7e71..91250a29 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -6,7 +6,7 @@ import torch from torch import nn -from .. import batches, defaults +from .. import data, defaults from . import base, modules @@ -105,10 +105,10 @@ def _decode_greedy( finished = torch.logical_or( finished, (predictions[-1] == self.end_idx) ) - # Breaks when all batches predicted an EOS symbol. - # If we have a target (and are thus computing loss), - # we only break when we have decoded at least the the - # same number of steps as the target length. + # Breaks when all sequences have predicted an EOS symbol. If we + # have a target (and are thus computing loss), we only break when + # we have decoded at least the the same number of steps as the + # target length. if finished.all(): if targets is None or len(outputs) >= targets.size(-1): break @@ -117,12 +117,12 @@ def _decode_greedy( def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> torch.Tensor: """Runs the encoder-decoder. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: torch.Tensor. diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index fa2003f7..fa572b62 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -4,9 +4,8 @@ import os import pytorch_lightning as pl -from torch.utils import data -from . import collators, dataconfig, datasets, defaults, models, util +from . 
import data, models, util def get_trainer_from_argparse_args( @@ -23,51 +22,36 @@ def get_trainer_from_argparse_args( return pl.Trainer.from_argparse_args(args, max_epochs=0) -def get_dataset_from_argparse_args( +def get_datamodule_from_argparse_args( args: argparse.Namespace, -) -> datasets.BaseDataset: +) -> data.DataModule: """Creates the dataset from CLI arguments. Args: args (argparse.Namespace). Returns: - datasets.BaseDataset. + data.DataModule. """ - config = dataconfig.DataConfig.from_argparse_args(args) - return datasets.get_dataset(args.predict, config, args.index) - - -def get_loader( - dataset: datasets.BaseDataset, - arch: str, - batch_size: int, - max_source_length: int, - max_target_length: int, -) -> data.DataLoader: - """Creates the loader. - - Args: - dataset (data.Dataset). - arch (str). - batch_size (int). - max_source_length (int). - max_target_length (int). - - Returns: - data.DataLoader. - """ - collator = collators.Collator( - dataset, - arch, - max_source_length, - max_target_length, - ) - return data.DataLoader( - dataset, - collate_fn=collator, - batch_size=batch_size, - num_workers=1, + separate_features = args.features_col != 0 and args.arch in [ + "pointer_generator_lstm", + "transducer", + ] + index = data.Index.read(args.model_dir, args.experiment) + return data.DataModule( + predict=args.predict, + batch_size=args.batch_size, + source_col=args.source_col, + features_col=args.features_col, + target_col=args.target_col, + source_sep=args.source_sep, + features_sep=args.features_sep, + target_sep=args.target_sep, + tied_vocabulary=args.tied_vocabulary, + separate_features=separate_features, + max_source_length=args.max_source_length, + max_target_length=args.max_target_length, + index=index, ) @@ -99,8 +83,8 @@ def _mkdir(output: str) -> None: def predict( trainer: pl.Trainer, - model: pl.LightningModule, - loader: data.DataLoader, + model: models.BaseEncoderDecoder, + datamodule: data.DataModule, output: str, ) -> None: """Predicts from the model. @@ -108,25 +92,19 @@ def predict( Args: trainer (pl.Trainer). model (pl.LightningModule). - loader (data.DataLoader). + datamdule (data.DataModule). output (str). - target_sep (str). """ - dataset = loader.dataset - target_sep = dataset.config.target_sep util.log_info(f"Writing to {output}") _mkdir(output) + loader = datamodule.predict_dataloader() with open(output, "w") as sink: - for batch in trainer.predict(model, dataloaders=loader): + for batch in trainer.predict(model, loader): batch = model.evaluator.finalize_predictions( - batch, dataset.index.end_idx, dataset.index.pad_idx + batch, datamodule.index.end_idx, datamodule.index.pad_idx ) - for prediction in dataset.decode_target( - batch, - symbols=True, - special=False, - ): - print(target_sep.join(prediction), file=sink) + for prediction in loader.dataset.decode_target(batch): + print(prediction, file=sink) def add_argparse_args(parser: argparse.ArgumentParser) -> None: @@ -135,10 +113,18 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: Args: parser (argparse.ArgumentParser). """ + # Path arguments. + parser.add_argument( + "--checkpoint", required=True, help="Path to checkpoint (.ckpt)." + ) + parser.add_argument( + "--model_dir", + required=True, + help="Path to output model directory.", + ) parser.add_argument( "--experiment", required=True, help="Name of experiment." ) - # Path arguments. 
parser.add_argument( "--predict", required=True, @@ -149,22 +135,10 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: required=True, help="Path to prediction output data TSV.", ) - parser.add_argument("--index", required=True, help="Path to index (.pkl).") - parser.add_argument( - "--checkpoint", required=True, help="Path to checkpoint (.ckpt)." - ) - # Predicting arguments. - parser.add_argument( - "--batch_size", - type=int, - default=defaults.BATCH_SIZE, - help="Batch size. Default: %(default)s.", - ) + # Prediction arguments. # TODO: add --beam_width. - # Data configuration arguments. - dataconfig.DataConfig.add_argparse_args(parser) - # Collator arguments. - collators.Collator.add_argparse_args(parser) + # Data arguments. + data.add_argparse_args(parser) # Architecture arguments; the architecture-specific ones are not needed. models.add_argparse_args(parser) # Among the things this adds, the following are likely to be useful: @@ -180,15 +154,9 @@ def main() -> None: args = parser.parse_args() util.log_arguments(args) trainer = get_trainer_from_argparse_args(args) - loader = get_loader( - get_dataset_from_argparse_args(args), - args.arch, - args.batch_size, - args.max_source_length, - args.max_target_length, - ) + datamodule = get_datamodule_from_argparse_args(args) model = get_model_from_argparse_args(args) - predict(trainer, model, loader, args.output) + predict(trainer, model, datamodule, args.output) if __name__ == "__main__": diff --git a/yoyodyne/train.py b/yoyodyne/train.py index 65d68a71..2e5a85e5 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -1,22 +1,14 @@ """Trains a sequence-to-sequence neural network.""" import argparse -from typing import List, Optional, Tuple +from typing import List, Optional import pytorch_lightning as pl import wandb from pytorch_lightning import callbacks, loggers -from torch.utils import data +from torch.utils import data as torch_data -from . import ( - collators, - dataconfig, - datasets, - defaults, - models, - schedulers, - util, -) +from . import data, defaults, models, schedulers, util class Error(Exception): @@ -43,7 +35,6 @@ def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List: wandb.define_metric("val_accuracy", summary="max") # Logs the path to local artifacts made by PTL. wandb.config.update({"local_run_dir": trainer_logger[0].log_dir}) - return trainer_logger @@ -103,84 +94,49 @@ def get_trainer_from_argparse_args( ) -def get_datasets_from_argparse_args( +def get_datamodule_from_argparse_args( args: argparse.Namespace, -) -> Tuple[datasets.BaseDataset, datasets.BaseDataset]: - """Creates the datasets from CLI arguments. +) -> data.DataModule: + """Creates the datamodule from CLI arguments. Args: - args (argparse.Namespace). + args (Argparse.Namespace). Returns: - Tuple[datasets.BaseDataset, datasets.BaseDataset]: the training and - development datasets. + data.DataModule. 
""" - config = dataconfig.DataConfig.from_argparse_args(args) - if config.target_col == 0: - raise Error("target_col must be specified for training") - train_set = datasets.get_dataset(args.train, config) - dev_set = datasets.get_dataset(args.dev, config, train_set.index) - util.log_info(f"Source vocabulary: {train_set.index.source_map.pprint()}") - if train_set.has_features: - util.log_info( - f"Feature vocabulary: {train_set.index.features_map.pprint()}" - ) - util.log_info(f"Target vocabulary: {train_set.index.target_map.pprint()}") - return train_set, dev_set - - -def get_loaders( - train_set: datasets.BaseDataset, - dev_set: datasets.BaseDataset, - arch: str, - batch_size: int, - max_source_length: int, - max_target_length: int, -) -> Tuple[data.DataLoader, data.DataLoader]: - """Creates the loaders. - - Args: - train_set (datasets.BaseDataset). - dev_set (datasets.BaseDataset). - arch (str). - batch_size (int). - max_source_length (int). - max_target_length (int). - - Returns: - Tuple[data.DataLoader, data.DataLoader]: the training and development - loaders. - """ - collator = collators.Collator( - train_set, - arch, - max_source_length, - max_target_length, - ) - train_loader = data.DataLoader( - train_set, - collate_fn=collator, - batch_size=batch_size, - shuffle=True, - num_workers=1, # Our data loading is simple. - ) - dev_loader = data.DataLoader( - dev_set, - collate_fn=collator, - batch_size=2 * batch_size, # Because we're not collecting gradients. - num_workers=1, + separate_features = args.features_col != 0 and args.arch in [ + "pointer_generator_lstm", + "transducer", + ] + datamodule = data.DataModule( + train=args.train, + val=args.val, + batch_size=args.batch_size, + source_col=args.source_col, + features_col=args.features_col, + target_col=args.target_col, + source_sep=args.source_sep, + features_sep=args.features_sep, + target_sep=args.target_sep, + tied_vocabulary=args.tied_vocabulary, + separate_features=separate_features, + max_source_length=args.max_source_length, + max_target_length=args.max_target_length, ) - return train_loader, dev_loader + datamodule.index.write(args.model_dir, args.experiment) + datamodule.log_vocabularies() + return datamodule def get_model_from_argparse_args( - train_set: datasets.BaseDataset, args: argparse.Namespace, + datamodule: data.DataModule, ) -> models.BaseEncoderDecoder: """Creates the model. Args: - train_set (datasets.BaseDataset). + train_set (data.BaseDataset). args (argparse.Namespace). 

     Returns:
@@ -192,7 +148,7 @@
     )
     expert = (
         models.expert.get_expert(
-            train_set,
+            datamodule.train_dataloader().dataset,
             epochs=args.oracle_em_epochs,
             oracle_factor=args.oracle_factor,
             sed_params_path=args.sed_params,
@@ -201,7 +157,7 @@
         else None
     )
     scheduler_kwargs = schedulers.get_scheduler_kwargs_from_argparse_args(args)
-    separate_features = train_set.has_features and args.arch in [
+    separate_features = datamodule.has_features and args.arch in [
         "pointer_generator_lstm",
         "transducer",
     ]
@@ -213,12 +169,12 @@
         else None
     )
     features_vocab_size = (
-        train_set.index.features_vocab_size if train_set.has_features else 0
+        datamodule.index.features_vocab_size if datamodule.has_features else 0
     )
     source_vocab_size = (
-        train_set.index.source_vocab_size + features_vocab_size
+        datamodule.index.source_vocab_size + features_vocab_size
         if not separate_features
-        else train_set.index.source_vocab_size
+        else datamodule.index.source_vocab_size
     )
     # Please pass all arguments by keyword and keep in lexicographic order.
     return model_cls(
@@ -231,7 +187,7 @@
         dropout=args.dropout,
         embedding_size=args.embedding_size,
         encoder_layers=args.encoder_layers,
-        end_idx=train_set.index.end_idx,
+        end_idx=datamodule.index.end_idx,
         expert=expert,
         features_encoder_cls=features_encoder_cls,
         features_vocab_size=features_vocab_size,
@@ -241,22 +197,22 @@
         max_source_length=args.max_source_length,
         max_target_length=args.max_target_length,
         optimizer=args.optimizer,
-        output_size=train_set.index.target_vocab_size,
-        pad_idx=train_set.index.pad_idx,
+        output_size=datamodule.index.target_vocab_size,
+        pad_idx=datamodule.index.pad_idx,
         scheduler=args.scheduler,
         scheduler_kwargs=scheduler_kwargs,
         source_encoder_cls=source_encoder_cls,
         source_vocab_size=source_vocab_size,
-        start_idx=train_set.index.start_idx,
-        target_vocab_size=train_set.index.target_vocab_size,
+        start_idx=datamodule.index.start_idx,
+        target_vocab_size=datamodule.index.target_vocab_size,
     )


 def train(
     trainer: pl.Trainer,
     model: models.BaseEncoderDecoder,
-    train_loader: data.DataLoader,
-    dev_loader: data.DataLoader,
+    train_loader: torch_data.DataLoader,
+    dev_loader: torch_data.DataLoader,
     train_from: Optional[str] = None,
 ) -> str:
     """Trains the model.
@@ -285,36 +241,30 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
     Args:
         argparse.ArgumentParser.
     """
+    # Path arguments.
+    parser.add_argument(
+        "--model_dir",
+        required=True,
+        help="Path to output model directory.",
+    )
     parser.add_argument(
         "--experiment", required=True, help="Name of experiment."
     )
-    # Path arguments.
     parser.add_argument(
         "--train",
         required=True,
         help="Path to input training data TSV.",
     )
     parser.add_argument(
-        "--dev",
+        "--val",
         required=True,
-        help="Path to input development data TSV.",
-    )
-    parser.add_argument(
-        "--model_dir",
-        required=True,
-        help="Path to output model directory.",
+        help="Path to input validation data TSV.",
     )
     parser.add_argument(
         "--train_from",
         help="Path to ckpt checkpoint to resume training from.",
     )
     # Other training arguments.
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=defaults.BATCH_SIZE,
-        help="Batch size. Default: %(default)s.",
-    )
     parser.add_argument(
         "--patience", type=int, help="Patience for early stopping."
     )
@@ -336,10 +286,8 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
         action="store_false",
         dest="log_wandb",
     )
-    # Data configuration arguments.
-    dataconfig.DataConfig.add_argparse_args(parser)
-    # Collator arguments.
-    collators.Collator.add_argparse_args(parser)
+    # Data arguments.
+    data.add_argparse_args(parser)
     # Architecture arguments.
     models.add_argparse_args(parser)
     models.modules.add_argparse_args(parser)
@@ -374,32 +322,19 @@ def main() -> None:
     util.log_arguments(args)
     pl.seed_everything(args.seed)
     trainer = get_trainer_from_argparse_args(args)
-    train_set, dev_set = get_datasets_from_argparse_args(args)
-    index = train_set.index.index_path(args.model_dir, args.experiment)
-    train_set.index.write(index)
-    util.log_info(f"Index: {index}")
-    train_loader, dev_loader = get_loaders(
-        train_set,
-        dev_set,
-        args.arch,
-        args.batch_size,
-        args.max_source_length,
-        args.max_target_length,
-    )
-    model = get_model_from_argparse_args(train_set, args)
+    datamodule = get_datamodule_from_argparse_args(args)
+    model = get_model_from_argparse_args(args, datamodule)
     # Tuning options. Batch autoscaling is unsupported; LR tuning logs the
     # suggested value and then exits.
     if args.auto_scale_batch_size:
         raise Error("Batch auto-scaling is not supported")
         return
     if args.auto_lr_find:
-        result = trainer.tuner.lr_find(model, train_loader, dev_loader)
+        result = trainer.tuner.lr_find(model, datamodule=datamodule)
         util.log_info(f"Best initial LR: {result.suggestion():.8f}")
         return
     # Otherwise, train and log the best checkpoint.
-    best_checkpoint = train(
-        trainer, model, train_loader, dev_loader, train_from=args.train_from
-    )
+    best_checkpoint = train(trainer, model, datamodule, train_from=args.train_from)
     util.log_info(f"Best checkpoint: {best_checkpoint}")
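
The hunks above replace the old dataconfig/collators/datasets plumbing with a
single data.DataModule plus a persisted data.Index. What follows is a minimal
sketch of the resulting round trip between training and prediction, not part
of the diff itself: it assumes that keyword arguments omitted here fall back
to package defaults, and the model directory, experiment name, and TSV paths
are placeholders.

# Sketch only; mirrors the added lines in train.py and predict.py above.
from yoyodyne import data

MODEL_DIR = "models"      # placeholder
EXPERIMENT = "example"    # placeholder

# Training side (cf. get_datamodule_from_argparse_args in train.py): the
# DataModule builds the vocabulary index from the TSVs and persists it
# under the model directory.
train_dm = data.DataModule(train="train.tsv", val="val.tsv")
train_dm.index.write(MODEL_DIR, EXPERIMENT)
train_dm.log_vocabularies()

# Prediction side (cf. get_datamodule_from_argparse_args in predict.py):
# the saved index is read back and handed to a new DataModule, replacing
# the old --index flag.
index = data.Index.read(MODEL_DIR, EXPERIMENT)
predict_dm = data.DataModule(predict="predict.tsv", index=index)
loader = predict_dm.predict_dataloader()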