From eb60f3720a86691290801c9e30421d3f2f5bf144 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Thu, 13 Jul 2023 07:13:02 -0400 Subject: [PATCH 01/18] Moves batches to submodule. --- yoyodyne/collators.py | 38 ++++++++++++-------------- yoyodyne/data/__init__.py | 1 + yoyodyne/{ => data}/batches.py | 2 +- yoyodyne/models/base.py | 14 +++++----- yoyodyne/models/lstm.py | 8 +++--- yoyodyne/models/modules/linear.py | 6 ++-- yoyodyne/models/modules/lstm.py | 6 ++-- yoyodyne/models/modules/transformer.py | 6 ++-- yoyodyne/models/pointer_generator.py | 8 +++--- yoyodyne/models/transducer.py | 16 ++++------- yoyodyne/models/transformer.py | 8 +++--- 11 files changed, 53 insertions(+), 60 deletions(-) create mode 100644 yoyodyne/data/__init__.py rename yoyodyne/{ => data}/batches.py (99%) diff --git a/yoyodyne/collators.py b/yoyodyne/collators.py index ddc0a712..f85af6b3 100644 --- a/yoyodyne/collators.py +++ b/yoyodyne/collators.py @@ -5,7 +5,7 @@ import torch -from . import batches, datasets, defaults, util +from . import data, datasets, defaults, util class LengthError(Exception): @@ -102,18 +102,16 @@ def concatenate_source_and_features( for item in itemlist ] - def pad_source( - self, itemlist: List[datasets.Item] - ) -> batches.PaddedTensor: + def pad_source(self, itemlist: List[datasets.Item]) -> data.PaddedTensor: """Pads source. Args: itemlist (List[datasets.Item]). Returns: - batches.PaddedTensor. + data.PaddedTensor. """ - return batches.PaddedTensor( + return data.PaddedTensor( [item.source for item in itemlist], self.pad_idx, self._source_length_error, @@ -122,16 +120,16 @@ def pad_source( def pad_source_features( self, itemlist: List[datasets.Item], - ) -> batches.PaddedTensor: + ) -> data.PaddedTensor: """Pads concatenated source and features. Args: itemlist (List[datasets.Item]). Returns: - batches.PaddedTensor. + data.PaddedTensor. """ - return batches.PaddedTensor( + return data.PaddedTensor( self.concatenate_source_and_features(itemlist), self.pad_idx, self._source_length_error, @@ -140,54 +138,52 @@ def pad_source_features( def pad_features( self, itemlist: List[datasets.Item], - ) -> batches.PaddedTensor: + ) -> data.PaddedTensor: """Pads features. Args: itemlist (List[datasets.Item]). Returns: - batches.PaddedTensor. + data.PaddedTensor. """ - return batches.PaddedTensor( + return data.PaddedTensor( [item.features for item in itemlist], self.pad_idx ) - def pad_target( - self, itemlist: List[datasets.Item] - ) -> batches.PaddedTensor: + def pad_target(self, itemlist: List[datasets.Item]) -> data.PaddedTensor: """Pads target. Args: itemlist (List[datasets.Item]). Returns: - batches.PaddedTensor. + data.PaddedTensor. """ - return batches.PaddedTensor( + return data.PaddedTensor( [item.target for item in itemlist], self.pad_idx, self._target_length_warning, ) - def __call__(self, itemlist: List[datasets.Item]) -> batches.PaddedBatch: + def __call__(self, itemlist: List[datasets.Item]) -> data.PaddedBatch: """Pads all elements of an itemlist. Args: itemlist (List[datasets.Item]). Returns: - batches.PaddedBatch. + data.PaddedBatch. """ padded_target = self.pad_target(itemlist) if self.has_target else None if self.separate_features: - return batches.PaddedBatch( + return data.PaddedBatch( self.pad_source(itemlist), features=self.pad_features(itemlist), target=padded_target, ) else: - return batches.PaddedBatch( + return data.PaddedBatch( self.pad_source_features(itemlist), target=padded_target, ) diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py new file mode 100644 index 00000000..81eb60b7 --- /dev/null +++ b/yoyodyne/data/__init__.py @@ -0,0 +1 @@ +from .batches import PaddedBatch, PaddedTensor # noqa: F401 diff --git a/yoyodyne/batches.py b/yoyodyne/data/batches.py similarity index 99% rename from yoyodyne/batches.py rename to yoyodyne/data/batches.py index dd42a6cf..2e1648b8 100644 --- a/yoyodyne/batches.py +++ b/yoyodyne/data/batches.py @@ -28,7 +28,7 @@ def __init__( ): """Constructs the padded tensor from a list of tensors. - The optional pad_len argument can be used, e.g., to keep all batches + The optional pad_len argument can be used, e.g., to keep all data the exact same length, which improves performance on certain accelerators. If not specified, it will be computed using the length of the longest input tensor. diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index 72774e0b..2c7db7ea 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -10,7 +10,7 @@ import torch from torch import nn, optim -from .. import batches, defaults, evaluators, schedulers, util +from .. import data, defaults, evaluators, schedulers, util from . import modules @@ -235,7 +235,7 @@ def has_features_encoder(self): def training_step( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, batch_idx: int, ) -> torch.Tensor: """Runs one step of training. @@ -243,7 +243,7 @@ def training_step( This is called by the PL Trainer. Args: - batch (batches.PaddedBatch) + batch (data.PaddedBatch) batch_idx (int). Returns: @@ -266,7 +266,7 @@ def training_step( def validation_step( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, batch_idx: int, ) -> Dict: """Runs one validation step. @@ -274,7 +274,7 @@ def validation_step( This is called by the PL Trainer. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). batch_idx (int). Returns: @@ -316,7 +316,7 @@ def validation_epoch_end(self, validation_step_outputs: Dict) -> Dict: def predict_step( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, batch_idx: int, ) -> torch.Tensor: """Runs one predict step. @@ -324,7 +324,7 @@ def predict_step( This is called by the PL Trainer. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). batch_idx (int). Returns: diff --git a/yoyodyne/models/lstm.py b/yoyodyne/models/lstm.py index 846271a6..0619a985 100644 --- a/yoyodyne/models/lstm.py +++ b/yoyodyne/models/lstm.py @@ -7,7 +7,7 @@ import torch from torch import nn -from .. import batches, defaults +from .. import data, defaults from . import base, modules @@ -139,7 +139,7 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all batches predicted an EOS symbol. + # Breaks when all data predicted an EOS symbol. # If we have a target (and are thus computing loss), # we only break when we have decoded at least the the # same number of steps as the target length. @@ -279,12 +279,12 @@ def beam_decode( def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> torch.Tensor: """Runs the encoder-decoder model. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: predictions (torch.Tensor): tensor of predictions of shape diff --git a/yoyodyne/models/modules/linear.py b/yoyodyne/models/modules/linear.py index 4248289a..aeb79611 100644 --- a/yoyodyne/models/modules/linear.py +++ b/yoyodyne/models/modules/linear.py @@ -5,7 +5,7 @@ import torch from torch import nn -from ... import batches +from ... import data from . import base @@ -32,12 +32,12 @@ def init_embeddings( class LinearEncoder(LinearModule): def forward( - self, source: batches.PaddedTensor + self, source: data.PaddedTensor ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Encodes the input. Args: - source (batches.PaddedTensor): source padded tensors and mask + source (data.PaddedTensor): source padded tensors and mask for source, of shape B x seq_len x 1. Returns: diff --git a/yoyodyne/models/modules/lstm.py b/yoyodyne/models/modules/lstm.py index e982bacb..52598d6e 100644 --- a/yoyodyne/models/modules/lstm.py +++ b/yoyodyne/models/modules/lstm.py @@ -5,7 +5,7 @@ import torch from torch import nn -from ... import batches, defaults +from ... import data, defaults from . import attention, base @@ -58,12 +58,12 @@ def init_embeddings( class LSTMEncoder(LSTMModule): def forward( - self, source: batches.PaddedTensor + self, source: data.PaddedTensor ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Encodes the input. Args: - source (batches.PaddedTensor): source padded tensors and mask + source (data.PaddedTensor): source padded tensors and mask for source, of shape B x seq_len x 1. Returns: diff --git a/yoyodyne/models/modules/transformer.py b/yoyodyne/models/modules/transformer.py index d2eef484..8a20508b 100644 --- a/yoyodyne/models/modules/transformer.py +++ b/yoyodyne/models/modules/transformer.py @@ -6,7 +6,7 @@ import torch from torch import nn -from ... import batches +from ... import data from . import base @@ -146,11 +146,11 @@ def embed(self, symbols: torch.Tensor) -> torch.Tensor: class TransformerEncoder(TransformerModule): - def forward(self, source: batches.PaddedTensor) -> torch.Tensor: + def forward(self, source: data.PaddedTensor) -> torch.Tensor: """Encodes the source with the TransformerEncoder. Args: - source (batches.PaddedTensor). + source (data.PaddedTensor). Returns: torch.Tensor: sequence of encoded symbols. diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index 30e9087c..18e07124 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -6,7 +6,7 @@ import torch from torch import nn -from .. import batches +from .. import data from . import lstm, modules @@ -289,7 +289,7 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all batches predicted an EOS symbol. + # Breaks when all data predicted an EOS symbol. # If we have a target (and are thus computing loss), # we only break when we have decoded at least the the # same number of steps as the target length. @@ -303,12 +303,12 @@ def decode( def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> torch.Tensor: """Runs the encoder-decoder. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: torch.Tensor. diff --git a/yoyodyne/models/transducer.py b/yoyodyne/models/transducer.py index 7f35fd99..3cb7d96b 100644 --- a/yoyodyne/models/transducer.py +++ b/yoyodyne/models/transducer.py @@ -8,7 +8,7 @@ from maxwell import actions from torch import nn -from .. import batches +from .. import data from . import expert, lstm, modules @@ -71,12 +71,12 @@ def get_decoder(self) -> modules.lstm.LSTMDecoder: def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> Tuple[List[List[int]], torch.Tensor]: """Runs the encoder-decoder model. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: Tuple[List[List[int]], torch.Tensor] of encoded prediction values @@ -502,15 +502,13 @@ def _get_loss_func( # Prevents base construction of unused loss function. return None - def training_step( - self, batch: batches.PaddedBatch, batch_idx: int - ) -> Dict: + def training_step(self, batch: data.PaddedBatch, batch_idx: int) -> Dict: """Runs one step of training. This is called by the PL Trainer. Args: - batch (batches.PaddedBatch) + batch (data.PaddedBatch) batch_idx (int). Returns: @@ -527,9 +525,7 @@ def training_step( ) return loss - def validation_step( - self, batch: batches.PaddedBatch, batch_idx: int - ) -> Dict: + def validation_step(self, batch: data.PaddedBatch, batch_idx: int) -> Dict: predictions, loss = self(batch) # Evaluation requires prediction as a tensor. predictions = self.convert_prediction(predictions) diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index 354f7e71..5b000e91 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -6,7 +6,7 @@ import torch from torch import nn -from .. import batches, defaults +from .. import data, defaults from . import base, modules @@ -105,7 +105,7 @@ def _decode_greedy( finished = torch.logical_or( finished, (predictions[-1] == self.end_idx) ) - # Breaks when all batches predicted an EOS symbol. + # Breaks when all data predicted an EOS symbol. # If we have a target (and are thus computing loss), # we only break when we have decoded at least the the # same number of steps as the target length. @@ -117,12 +117,12 @@ def _decode_greedy( def forward( self, - batch: batches.PaddedBatch, + batch: data.PaddedBatch, ) -> torch.Tensor: """Runs the encoder-decoder. Args: - batch (batches.PaddedBatch). + batch (data.PaddedBatch). Returns: torch.Tensor. From ffe7851480fd3cec32dfabdcac33c82b33b306a9 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 15 Jul 2023 19:43:58 -0400 Subject: [PATCH 02/18] Moves datasets in. --- yoyodyne/collators.py | 26 +++++++++++++------------- yoyodyne/data/__init__.py | 6 ++++++ yoyodyne/{ => data}/datasets.py | 2 +- yoyodyne/predict.py | 10 +++++----- yoyodyne/train.py | 30 +++++++++++++++--------------- 5 files changed, 40 insertions(+), 34 deletions(-) rename yoyodyne/{ => data}/datasets.py (99%) diff --git a/yoyodyne/collators.py b/yoyodyne/collators.py index f85af6b3..c24f3707 100644 --- a/yoyodyne/collators.py +++ b/yoyodyne/collators.py @@ -5,7 +5,7 @@ import torch -from . import data, datasets, defaults, util +from . import data, defaults, util class LengthError(Exception): @@ -25,7 +25,7 @@ class Collator: def __init__( self, - dataset: datasets.BaseDataset, + dataset: data.BaseDataset, arch: str, max_source_length: int = defaults.MAX_SOURCE_LENGTH, max_target_length: int = defaults.MAX_TARGET_LENGTH, @@ -90,7 +90,7 @@ def _target_length_warning(self, padded_length: int) -> None: def concatenate_source_and_features( self, - itemlist: List[datasets.Item], + itemlist: List[data.Item], ) -> List[torch.Tensor]: """Concatenates source and feature tensors.""" return [ @@ -102,11 +102,11 @@ def concatenate_source_and_features( for item in itemlist ] - def pad_source(self, itemlist: List[datasets.Item]) -> data.PaddedTensor: + def pad_source(self, itemlist: List[data.Item]) -> data.PaddedTensor: """Pads source. Args: - itemlist (List[datasets.Item]). + itemlist (List[data.Item]). Returns: data.PaddedTensor. @@ -119,12 +119,12 @@ def pad_source(self, itemlist: List[datasets.Item]) -> data.PaddedTensor: def pad_source_features( self, - itemlist: List[datasets.Item], + itemlist: List[data.Item], ) -> data.PaddedTensor: """Pads concatenated source and features. Args: - itemlist (List[datasets.Item]). + itemlist (List[data.Item]). Returns: data.PaddedTensor. @@ -137,12 +137,12 @@ def pad_source_features( def pad_features( self, - itemlist: List[datasets.Item], + itemlist: List[data.Item], ) -> data.PaddedTensor: """Pads features. Args: - itemlist (List[datasets.Item]). + itemlist (List[data.Item]). Returns: data.PaddedTensor. @@ -151,11 +151,11 @@ def pad_features( [item.features for item in itemlist], self.pad_idx ) - def pad_target(self, itemlist: List[datasets.Item]) -> data.PaddedTensor: + def pad_target(self, itemlist: List[data.Item]) -> data.PaddedTensor: """Pads target. Args: - itemlist (List[datasets.Item]). + itemlist (List[data.Item]). Returns: data.PaddedTensor. @@ -166,11 +166,11 @@ def pad_target(self, itemlist: List[datasets.Item]) -> data.PaddedTensor: self._target_length_warning, ) - def __call__(self, itemlist: List[datasets.Item]) -> data.PaddedBatch: + def __call__(self, itemlist: List[data.Item]) -> data.PaddedBatch: """Pads all elements of an itemlist. Args: - itemlist (List[datasets.Item]). + itemlist (List[data.Item]). Returns: data.PaddedBatch. diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index 81eb60b7..d7600af1 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -1 +1,7 @@ from .batches import PaddedBatch, PaddedTensor # noqa: F401 +from .datasets import ( + BaseDataset, + DatasetNoFeatures, + DatasetFeatures, + get_dataset, +) # noqa: F401 diff --git a/yoyodyne/datasets.py b/yoyodyne/data/datasets.py similarity index 99% rename from yoyodyne/datasets.py rename to yoyodyne/data/datasets.py index 6966529b..dc8d2a2e 100644 --- a/yoyodyne/datasets.py +++ b/yoyodyne/data/datasets.py @@ -10,7 +10,7 @@ from torch import nn from torch.utils import data -from . import dataconfig, indexes, special +from .. import dataconfig, indexes, special class Item(nn.Module): diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index fa2003f7..7411bbb2 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -6,7 +6,7 @@ import pytorch_lightning as pl from torch.utils import data -from . import collators, dataconfig, datasets, defaults, models, util +from . import collators, dataconfig, data, defaults, models, util def get_trainer_from_argparse_args( @@ -25,21 +25,21 @@ def get_trainer_from_argparse_args( def get_dataset_from_argparse_args( args: argparse.Namespace, -) -> datasets.BaseDataset: +) -> data.BaseDataset: """Creates the dataset from CLI arguments. Args: args (argparse.Namespace). Returns: - datasets.BaseDataset. + data.BaseDataset. """ config = dataconfig.DataConfig.from_argparse_args(args) - return datasets.get_dataset(args.predict, config, args.index) + return data.get_dataset(args.predict, config, args.index) def get_loader( - dataset: datasets.BaseDataset, + dataset: data.BaseDataset, arch: str, batch_size: int, max_source_length: int, diff --git a/yoyodyne/train.py b/yoyodyne/train.py index 65d68a71..1801f493 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -11,7 +11,7 @@ from . import ( collators, dataconfig, - datasets, + data, defaults, models, schedulers, @@ -103,23 +103,23 @@ def get_trainer_from_argparse_args( ) -def get_datasets_from_argparse_args( +def get_data_from_argparse_args( args: argparse.Namespace, -) -> Tuple[datasets.BaseDataset, datasets.BaseDataset]: - """Creates the datasets from CLI arguments. +) -> Tuple[data.BaseDataset, data.BaseDataset]: + """Creates the data from CLI arguments. Args: args (argparse.Namespace). Returns: - Tuple[datasets.BaseDataset, datasets.BaseDataset]: the training and - development datasets. + Tuple[data.BaseDataset, data.BaseDataset]: the training and + development data. """ config = dataconfig.DataConfig.from_argparse_args(args) if config.target_col == 0: raise Error("target_col must be specified for training") - train_set = datasets.get_dataset(args.train, config) - dev_set = datasets.get_dataset(args.dev, config, train_set.index) + train_set = data.get_dataset(args.train, config) + dev_set = data.get_dataset(args.dev, config, train_set.index) util.log_info(f"Source vocabulary: {train_set.index.source_map.pprint()}") if train_set.has_features: util.log_info( @@ -130,8 +130,8 @@ def get_datasets_from_argparse_args( def get_loaders( - train_set: datasets.BaseDataset, - dev_set: datasets.BaseDataset, + train_set: data.BaseDataset, + dev_set: data.BaseDataset, arch: str, batch_size: int, max_source_length: int, @@ -140,8 +140,8 @@ def get_loaders( """Creates the loaders. Args: - train_set (datasets.BaseDataset). - dev_set (datasets.BaseDataset). + train_set (data.BaseDataset). + dev_set (data.BaseDataset). arch (str). batch_size (int). max_source_length (int). @@ -174,13 +174,13 @@ def get_loaders( def get_model_from_argparse_args( - train_set: datasets.BaseDataset, + train_set: data.BaseDataset, args: argparse.Namespace, ) -> models.BaseEncoderDecoder: """Creates the model. Args: - train_set (datasets.BaseDataset). + train_set (data.BaseDataset). args (argparse.Namespace). Returns: @@ -374,7 +374,7 @@ def main() -> None: util.log_arguments(args) pl.seed_everything(args.seed) trainer = get_trainer_from_argparse_args(args) - train_set, dev_set = get_datasets_from_argparse_args(args) + train_set, dev_set = get_data_from_argparse_args(args) index = train_set.index.index_path(args.model_dir, args.experiment) train_set.index.write(index) util.log_info(f"Index: {index}") From cfee22df4bbeadc9c3e3410b81e19519488f1c8e Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 15 Jul 2023 19:55:00 -0400 Subject: [PATCH 03/18] Fixes gaps from previous commit. --- yoyodyne/data/__init__.py | 1 + yoyodyne/predict.py | 12 ++++++------ yoyodyne/train.py | 12 ++++++------ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index d7600af1..0982c408 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -1,5 +1,6 @@ from .batches import PaddedBatch, PaddedTensor # noqa: F401 from .datasets import ( + Item, BaseDataset, DatasetNoFeatures, DatasetFeatures, diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 7411bbb2..06817894 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -4,7 +4,7 @@ import os import pytorch_lightning as pl -from torch.utils import data +from torch.utils import data as torch_data from . import collators, dataconfig, data, defaults, models, util @@ -44,7 +44,7 @@ def get_loader( batch_size: int, max_source_length: int, max_target_length: int, -) -> data.DataLoader: +) -> torch_data.DataLoader: """Creates the loader. Args: @@ -55,7 +55,7 @@ def get_loader( max_target_length (int). Returns: - data.DataLoader. + torch_data.DataLoader. """ collator = collators.Collator( dataset, @@ -63,7 +63,7 @@ def get_loader( max_source_length, max_target_length, ) - return data.DataLoader( + return torch_data.DataLoader( dataset, collate_fn=collator, batch_size=batch_size, @@ -100,7 +100,7 @@ def _mkdir(output: str) -> None: def predict( trainer: pl.Trainer, model: pl.LightningModule, - loader: data.DataLoader, + loader: torch_data.DataLoader, output: str, ) -> None: """Predicts from the model. @@ -108,7 +108,7 @@ def predict( Args: trainer (pl.Trainer). model (pl.LightningModule). - loader (data.DataLoader). + loader (torch_data.DataLoader). output (str). target_sep (str). """ diff --git a/yoyodyne/train.py b/yoyodyne/train.py index 1801f493..e624ce92 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -6,7 +6,7 @@ import pytorch_lightning as pl import wandb from pytorch_lightning import callbacks, loggers -from torch.utils import data +from torch.utils import data as torch_data from . import ( collators, @@ -136,7 +136,7 @@ def get_loaders( batch_size: int, max_source_length: int, max_target_length: int, -) -> Tuple[data.DataLoader, data.DataLoader]: +) -> Tuple[torch_data.DataLoader, torch_data.DataLoader]: """Creates the loaders. Args: @@ -157,14 +157,14 @@ def get_loaders( max_source_length, max_target_length, ) - train_loader = data.DataLoader( + train_loader = torch_data.DataLoader( train_set, collate_fn=collator, batch_size=batch_size, shuffle=True, num_workers=1, # Our data loading is simple. ) - dev_loader = data.DataLoader( + dev_loader = torch_data.DataLoader( dev_set, collate_fn=collator, batch_size=2 * batch_size, # Because we're not collecting gradients. @@ -255,8 +255,8 @@ def get_model_from_argparse_args( def train( trainer: pl.Trainer, model: models.BaseEncoderDecoder, - train_loader: data.DataLoader, - dev_loader: data.DataLoader, + train_loader: torch_data.DataLoader, + dev_loader: torch_data.DataLoader, train_from: Optional[str] = None, ) -> str: """Trains the model. From d8dee251f4d29ea5b7c53b87b7e6d6fc394e0135 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 15 Jul 2023 20:05:16 -0400 Subject: [PATCH 04/18] Migrates collator too. --- yoyodyne/data/__init__.py | 37 +++++++++++++++++---- yoyodyne/{ => data}/collators.py | 57 +++++++++++++++++--------------- yoyodyne/predict.py | 8 ++--- yoyodyne/train.py | 7 ++-- 4 files changed, 68 insertions(+), 41 deletions(-) rename yoyodyne/{ => data}/collators.py (81%) diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index 0982c408..0dfdea3b 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -1,8 +1,31 @@ +import argparse + +from .. import defaults + from .batches import PaddedBatch, PaddedTensor # noqa: F401 -from .datasets import ( - Item, - BaseDataset, - DatasetNoFeatures, - DatasetFeatures, - get_dataset, -) # noqa: F401 +from .collators import Collator # noqa: F401 +from .datasets import Item # noqa: F401 +from .datasets import BaseDataset # noqa: F401 +from .datasets import DatasetNoFeatures # noqa: F401 +from .datasets import DatasetFeatures # noqa: F401 +from .datasets import get_dataset # noqa: F401 + + +def add_argparse_args(parser: argparse.ArgumentParser) -> None: + """Adds collator options to the argument parser. + + Args: + parser (argparse.ArgumentParser). + """ + parser.add_argument( + "--max_source_length", + type=int, + default=defaults.MAX_SOURCE_LENGTH, + help="Maximum source string length. Default: %(default)s.", + ) + parser.add_argument( + "--max_target_length", + type=int, + default=defaults.MAX_TARGET_LENGTH, + help="Maximum target string length. Default: %(default)s.", + ) diff --git a/yoyodyne/collators.py b/yoyodyne/data/collators.py similarity index 81% rename from yoyodyne/collators.py rename to yoyodyne/data/collators.py index c24f3707..dbc4de0f 100644 --- a/yoyodyne/collators.py +++ b/yoyodyne/data/collators.py @@ -5,7 +5,8 @@ import torch -from . import data, defaults, util +from .. import defaults, util +from . import batches, datasets class LengthError(Exception): @@ -25,7 +26,7 @@ class Collator: def __init__( self, - dataset: data.BaseDataset, + dataset: datasets.BaseDataset, arch: str, max_source_length: int = defaults.MAX_SOURCE_LENGTH, max_target_length: int = defaults.MAX_TARGET_LENGTH, @@ -90,7 +91,7 @@ def _target_length_warning(self, padded_length: int) -> None: def concatenate_source_and_features( self, - itemlist: List[data.Item], + itemlist: List[datasets.Item], ) -> List[torch.Tensor]: """Concatenates source and feature tensors.""" return [ @@ -102,16 +103,18 @@ def concatenate_source_and_features( for item in itemlist ] - def pad_source(self, itemlist: List[data.Item]) -> data.PaddedTensor: + def pad_source( + self, itemlist: List[datasets.Item] + ) -> batches.PaddedTensor: """Pads source. Args: - itemlist (List[data.Item]). + itemlist (List[datasets.Item]). Returns: - data.PaddedTensor. + batches.PaddedTensor. """ - return data.PaddedTensor( + return batches.PaddedTensor( [item.source for item in itemlist], self.pad_idx, self._source_length_error, @@ -119,17 +122,17 @@ def pad_source(self, itemlist: List[data.Item]) -> data.PaddedTensor: def pad_source_features( self, - itemlist: List[data.Item], - ) -> data.PaddedTensor: + itemlist: List[datasets.Item], + ) -> batches.PaddedTensor: """Pads concatenated source and features. Args: - itemlist (List[data.Item]). + itemlist (List[datasets.Item]). Returns: - data.PaddedTensor. + batches.PaddedTensor. """ - return data.PaddedTensor( + return batches.PaddedTensor( self.concatenate_source_and_features(itemlist), self.pad_idx, self._source_length_error, @@ -137,53 +140,55 @@ def pad_source_features( def pad_features( self, - itemlist: List[data.Item], - ) -> data.PaddedTensor: + itemlist: List[datasets.Item], + ) -> batches.PaddedTensor: """Pads features. Args: - itemlist (List[data.Item]). + itemlist (List[datasets.Item]). Returns: - data.PaddedTensor. + batches.PaddedTensor. """ - return data.PaddedTensor( + return batches.PaddedTensor( [item.features for item in itemlist], self.pad_idx ) - def pad_target(self, itemlist: List[data.Item]) -> data.PaddedTensor: + def pad_target( + self, itemlist: List[datasets.Item] + ) -> batches.PaddedTensor: """Pads target. Args: - itemlist (List[data.Item]). + itemlist (List[datasets.Item]). Returns: - data.PaddedTensor. + batches.PaddedTensor. """ - return data.PaddedTensor( + return batches.PaddedTensor( [item.target for item in itemlist], self.pad_idx, self._target_length_warning, ) - def __call__(self, itemlist: List[data.Item]) -> data.PaddedBatch: + def __call__(self, itemlist: List[datasets.Item]) -> batches.PaddedBatch: """Pads all elements of an itemlist. Args: - itemlist (List[data.Item]). + itemlist (List[datasets.Item]). Returns: - data.PaddedBatch. + batches.PaddedBatch. """ padded_target = self.pad_target(itemlist) if self.has_target else None if self.separate_features: - return data.PaddedBatch( + return batches.PaddedBatch( self.pad_source(itemlist), features=self.pad_features(itemlist), target=padded_target, ) else: - return data.PaddedBatch( + return batches.PaddedBatch( self.pad_source_features(itemlist), target=padded_target, ) diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 06817894..67c0e9cd 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -6,7 +6,7 @@ import pytorch_lightning as pl from torch.utils import data as torch_data -from . import collators, dataconfig, data, defaults, models, util +from . import dataconfig, data, defaults, models, util def get_trainer_from_argparse_args( @@ -57,7 +57,7 @@ def get_loader( Returns: torch_data.DataLoader. """ - collator = collators.Collator( + collator = data.Collator( dataset, arch, max_source_length, @@ -163,8 +163,8 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: # TODO: add --beam_width. # Data configuration arguments. dataconfig.DataConfig.add_argparse_args(parser) - # Collator arguments. - collators.Collator.add_argparse_args(parser) + # Data arguments. + data.add_argparse_args(parser) # Architecture arguments; the architecture-specific ones are not needed. models.add_argparse_args(parser) # Among the things this adds, the following are likely to be useful: diff --git a/yoyodyne/train.py b/yoyodyne/train.py index e624ce92..4c11d4d1 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -9,7 +9,6 @@ from torch.utils import data as torch_data from . import ( - collators, dataconfig, data, defaults, @@ -151,7 +150,7 @@ def get_loaders( Tuple[data.DataLoader, data.DataLoader]: the training and development loaders. """ - collator = collators.Collator( + collator = data.Collator( train_set, arch, max_source_length, @@ -338,8 +337,8 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: ) # Data configuration arguments. dataconfig.DataConfig.add_argparse_args(parser) - # Collator arguments. - collators.Collator.add_argparse_args(parser) + # Data arguments. + data.add_argparse_args(parser) # Architecture arguments. models.add_argparse_args(parser) models.modules.add_argparse_args(parser) From ca0b0c19909c29363687ad77bb98571acd821c20 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 15 Jul 2023 20:31:20 -0400 Subject: [PATCH 05/18] Incomplete work on migrating datasets and indexes. --- yoyodyne/data/__init__.py | 66 +++++++++- yoyodyne/data/datasets.py | 77 +++-------- yoyodyne/{ => data}/indexes.py | 0 yoyodyne/data/tsv.py | 174 +++++++++++++++++++++++++ yoyodyne/dataconfig.py | 227 --------------------------------- yoyodyne/predict.py | 4 +- yoyodyne/train.py | 21 ++- 7 files changed, 267 insertions(+), 302 deletions(-) rename yoyodyne/{ => data}/indexes.py (100%) create mode 100644 yoyodyne/data/tsv.py delete mode 100644 yoyodyne/dataconfig.py diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index 0dfdea3b..aa588481 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -9,14 +9,76 @@ from .datasets import DatasetNoFeatures # noqa: F401 from .datasets import DatasetFeatures # noqa: F401 from .datasets import get_dataset # noqa: F401 +from .indexes import Index # noqa: F401 def add_argparse_args(parser: argparse.ArgumentParser) -> None: - """Adds collator options to the argument parser. - + """Adds data options to the argument parser. Args: parser (argparse.ArgumentParser). """ + parser.add_argument( + "--source_col", + type=int, + default=defaults.SOURCE_COL, + help="1-based index for source column. Default: %(default)s.", + ) + parser.add_argument( + "--target_col", + type=int, + default=defaults.TARGET_COL, + help="1-based index for target column. Default: %(default)s.", + ) + parser.add_argument( + "--features_col", + type=int, + default=defaults.FEATURES_COL, + help="1-based index for features column; " + "0 indicates the model will not use features. " + "Default: %(default)s.", + ) + parser.add_argument( + "--source_sep", + type=str, + default=defaults.SOURCE_SEP, + help="String used to split source string into symbols; " + "an empty string indicates that each Unicode codepoint " + "is its own symbol. Default: %(default)r.", + ) + parser.add_argument( + "--target_sep", + type=str, + default=defaults.TARGET_SEP, + help="String used to split target string into symbols; " + "an empty string indicates that each Unicode codepoint " + "is its own symbol. Default: %(default)r.", + ) + parser.add_argument( + "--features_sep", + type=str, + default=defaults.FEATURES_SEP, + help="String used to split features string into symbols; " + "an empty string indicates that each Unicode codepoint " + "is its own symbol. Default: %(default)r.", + ) + parser.add_argument( + "--tied_vocabulary", + action="store_true", + default=defaults.TIED_VOCABULARY, + help="Share source and target embeddings. Default: %(default)s.", + ) + parser.add_argument( + "--no_tied_vocabulary", + action="store_false", + dest="tied_vocabulary", + default=True, + ) + parser.add_argument( + "--batch_size", + type=int, + default=defaults.BATCH_SIZE, + help="Batch size. Default: %(default)s.", + ) parser.add_argument( "--max_source_length", type=int, diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py index dc8d2a2e..83de6077 100644 --- a/yoyodyne/data/datasets.py +++ b/yoyodyne/data/datasets.py @@ -10,7 +10,9 @@ from torch import nn from torch.utils import data -from .. import dataconfig, indexes, special +from .. import special + +from . import indexes, tsv class Item(nn.Module): @@ -54,49 +56,30 @@ def __init__(self): class DatasetNoFeatures(BaseDataset): """Dataset object without feature column.""" - filename: str - config: dataconfig.DataConfig - samples: List[List[str]] - index: indexes.Index + samples: List[str] + index: indexes.Index # Usually copied. + string_parser: tsv.StringParser # Ditto. def __init__( self, filename, - config, - index: Optional[indexes.Index] = None, + tsv_parser, + string_parser, + index: indexes.Index, ): """Initializes the dataset. Args: filename (str): input filename. - config (dataconfig.DataConfig): dataset configuration. + string_parser (tsv.StringParser). other (indexes.Index, optional): if provided, use this index to avoid recomputing it. """ super().__init__() - self.config = config - self.samples = list(self.config.samples(filename)) + self.samples = list(tsv_parser.samples(filename)) + self.string_parser = string_parser self.index = index if index is not None else self._make_index() - def _make_index(self) -> indexes.Index: - """Generates index.""" - source_vocabulary: Set[str] = set() - target_vocabulary: Set[str] = set() - if self.config.has_target: - for source, target in self.samples: - source_vocabulary.update(source) - target_vocabulary.update(target) - if self.config.tied_vocabulary: - source_vocabulary.update(target_vocabulary) - target_vocabulary.update(source_vocabulary) - else: - for source in self.samples: - source_vocabulary.update(source) - return indexes.Index( - source_vocabulary=sorted(source_vocabulary), - target_vocabulary=sorted(target_vocabulary), - ) - def encode( self, symbol_map: indexes.SymbolMap, @@ -270,31 +253,9 @@ def __len__(self) -> int: class DatasetFeatures(DatasetNoFeatures): """Dataset object with feature column.""" - def _make_index(self) -> indexes.Index: - """Generates index. - - Same as in superclass, but also handles features. - """ - source_vocabulary: Set[str] = set() - features_vocabulary: Set[str] = set() - target_vocabulary: Set[str] = set() - if self.config.has_target: - for source, features, target in self.samples: - source_vocabulary.update(source) - features_vocabulary.update(features) - target_vocabulary.update(target) - if self.config.tied_vocabulary: - source_vocabulary.update(target_vocabulary) - target_vocabulary.update(source_vocabulary) - else: - for source, features in self.samples: - source_vocabulary.update(source) - features_vocabulary.update(features) - return indexes.Index( - source_vocabulary=sorted(source_vocabulary), - features_vocabulary=sorted(features_vocabulary), - target_vocabulary=sorted(target_vocabulary), - ) + samples: List[str] + index: indexes.Index # Usually copied. + string_parser: tsv.StringParser # Ditto. def __getitem__(self, idx: int) -> Item: """Retrieves item by index. @@ -355,21 +316,21 @@ def decode_features( def get_dataset( filename: str, - config: dataconfig.DataConfig, + string_parser: tsv.StringParser, index: Union[indexes.Index, str, None] = None, ) -> data.Dataset: """Dataset factory. Args: filename (str): input filename. - config (dataconfig.DataConfig): dataset configuration. + string_parser (tsv.StringParser): string parser. index (Union[index.Index, str], optional): input index file, or path to index.pkl file. Returns: data.Dataset: the dataset. """ - cls = DatasetFeatures if config.has_features else DatasetNoFeatures + cls = DatasetFeatures if string_parser.has_features else DatasetNoFeatures if isinstance(index, str): index = indexes.Index.read(index) - return cls(filename, config, index) + return cls(filename, string_parser, index) diff --git a/yoyodyne/indexes.py b/yoyodyne/data/indexes.py similarity index 100% rename from yoyodyne/indexes.py rename to yoyodyne/data/indexes.py diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py new file mode 100644 index 00000000..8a448d79 --- /dev/null +++ b/yoyodyne/data/tsv.py @@ -0,0 +1,174 @@ +"""TSV parsing. + +The TsvParser yield string tuples from TSV files using 1-based indexing. + +The CellParser converts between raw strings ("strings") and lists of string +symbols. +""" + + +import csv +import dataclasses +from typing import Iterator, List, Tuple + +from .. import defaults, util + + +class Error(Exception): + """Module-specific exception.""" + + pass + + +@dataclasses.dataclass +class TsvParser: + """Streams rows from a TSV file. + + Args: + source_col (int, optional): 1-indexed column in TSV containing + source strings. + features_col (int, optional): 1-indexed column in TSV containing + features strings. + target_col (int, optional): 1-indexed column in TSV containing + target strings. + """ + + source_col: int = defaults.SOURCE_COL + features_col: int = defaults.FEATURES_COL + target_col: int = defaults.TARGET_COL + + def __post_init__(self) -> None: + # This is automatically called after initialization. + if self.source_col < 1: + raise Error(f"Invalid source column: {self.source_col}") + if self.features_col < 0: + raise Error(f"Invalid features column: {self.features_col}") + if self.features_col != 0: + util.log_info("Including features") + if self.target_col < 0: + raise Error(f"Invalid target column: {self.target_col}") + if self.target_col == 0: + util.log_info("Ignoring targets in input") + + @staticmethod + def _tsv_reader(path: str) -> Iterator[str]: + with open(path, "r") as tsv: + yield from csv.reader(tsv, delimiter="\t") + + @staticmethod + def _get_string(row: List[str], col: int) -> str: + """Returns a string from a row by index. + Args: + row (List[str]): the split row. + col (int): the column index. + Returns: + str: symbol from that string. + """ + return row[col - 1] # -1 because we're using one-based indexing. + + @property + def has_source(self) -> bool: + return True + + @property + def has_features(self) -> bool: + return self.features_col != 0 + + @property + def has_target(self) -> bool: + return self.target_col != 0 + + def source_samples(self, path: str) -> Iterator[str]: + """Yields source.""" + for row in self._tsv_reader(path): + yield self._get_string(row, self.source_col) + + def source_target_samples(self, path: str) -> Iterator[Tuple[str, str]]: + """Yields source and target.""" + for row in self._tsv_reader(path): + source = self._get_string(row, self.source_col) + target = self._get_string(row, self.target_col) + yield source, target + + def source_features_target_samples( + self, path: str + ) -> Iterator[Tuple[str, str, str]]: + """Yields source, features, and target.""" + for row in self._tsv_reader(path): + source = self._get_string(row, self.source_col) + features = self._get_string(row, self.features_col) + target = self._get_string(row, self.target_col) + yield source, features, target + + def source_features_samples(self, path: str) -> Iterator[Tuple[str, str]]: + """Yields source, and features.""" + for row in self._tsv_reader(path): + source = self._get_string(row, self.source_col) + features = self._get_string(row, self.features_col) + yield source, features + + def samples(self, path: str) -> Iterator[Tuple[str, ...]]: + """Picks the right one.""" + if self.has_features: + if self.has_target: + self.source_features_target_samples(path) + else: + return self.source_features_samples(path) + elif self.has_target: + return self.source_target_samples(path) + else: + return self.source_samples(path) + + +@dataclasses.dataclass +class StringParser: + """Parses strings from the TSV file into lists of symbols. + + Args: + source_sep (str, optional): string used to split source string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + features_sep (str, optional): string used to split features string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + target_sep (str, optional): string used to split target string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + """ + + source_sep: str = defaults.SOURCE_SEP + features_sep: str = defaults.FEATURES_SEP + target_sep: str = defaults.TARGET_SEP + + # Parsing methods. + + @staticmethod + def _get_symbols(string: str, sep: str) -> List[str]: + return list(string) if not sep else sep.split(cell) + + def source_symbols(self, string: str) -> List[str]: + return self._get_symbols(string, self.features_sep) + + def features_symbols(self, string: str) -> List[str]: + # We deliberately obfuscate these to avoid overlap with source. + return [ + f"[{symbol}]" + for symbol in self._get_symbols(string, self.features_sep) + ] + + def target_symbols(self, string: str) -> List[str]: + return self._get_symbols(string, self.target_sep) + + # Deserialization methods. + + def source_string(self, symbols: List[str]) -> str: + return self.source_sep.join(symbols) + + def features_string(self, symbols: List[str]) -> str: + return self.features_sep.join( + # This indexing strips off the obfuscation. + [symbol[1:-1] for symbol in symbols], + ) + + def target_string(self, symbols: List[str]) -> str: + return self.target_sep.join(symbols) diff --git a/yoyodyne/dataconfig.py b/yoyodyne/dataconfig.py deleted file mode 100644 index 728f9e0b..00000000 --- a/yoyodyne/dataconfig.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Dataset config class.""" - -import argparse -import csv -import dataclasses -import inspect -from typing import Iterator, List, Tuple - -from . import defaults, util - - -class Error(Exception): - """Module-specific exception.""" - - pass - - -@dataclasses.dataclass -class DataConfig: - """Configuration specifications for a dataset. - - Args: - source_col (int, optional): 1-indexed column in TSV containing - source strings. - target_col (int, optional): 1-indexed column in TSV containing - target strings. - features_col (int, optional): 1-indexed column in TSV containing - features strings. - source_sep (str, optional): separator character between symbol in - source string. "" treats each character in source as a symbol. - target_sep (str, optional): separator character between symbol in - target string. "" treats each character in target as a symbol. - features_sep (str, optional): separator character between symbol in - features string. "" treats each character in features as a symbol. - tied_vocabulary (bool, optional): whether the source and target - should share a vocabulary. - """ - - source_col: int = defaults.SOURCE_COL - target_col: int = defaults.TARGET_COL - features_col: int = defaults.FEATURES_COL - source_sep: str = defaults.SOURCE_SEP - target_sep: str = defaults.TARGET_SEP - features_sep: str = defaults.FEATURES_SEP - tied_vocabulary: bool = defaults.TIED_VOCABULARY - - def __post_init__(self) -> None: - # This is automatically called after initialization. - if self.source_col < 1: - raise Error(f"Invalid source column: {self.source_col}") - if self.target_col < 0: - raise Error(f"Invalid target column: {self.target_col}") - if self.target_col == 0: - util.log_info("Ignoring targets in input") - if self.features_col < 0: - raise Error(f"Invalid features column: {self.features_col}") - if self.features_col != 0: - util.log_info("Including features") - - @classmethod - def from_argparse_args(cls, args, **kwargs): - """Creates an instance from CLI arguments.""" - params = vars(args) - valid_kwargs = inspect.signature(cls.__init__).parameters - dataconfig_kwargs = { - name: params[name] for name in valid_kwargs if name in params - } - dataconfig_kwargs.update(**kwargs) - return cls(**dataconfig_kwargs) - - @staticmethod - def _get_cell(row: List[str], col: int, sep: str) -> List[str]: - """Returns the split cell of a row. - - Args: - row (List[str]): the split row. - col (int): the column index - sep (str): the string to split the column on; if the empty string, - the column is split into characters instead. - - Returns: - List[str]: symbol from that cell. - """ - cell = row[col - 1] # -1 because we're using one-based indexing. - return list(cell) if not sep else cell.split(sep) - - # Source is always present. - - @property - def has_target(self) -> bool: - return self.target_col != 0 - - @property - def has_features(self) -> bool: - return self.features_col != 0 - - def source_samples(self, filename: str) -> Iterator[List[str]]: - """Yields source.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - yield self._get_cell(row, self.source_col, self.source_sep) - - def source_target_samples( - self, filename: str - ) -> Iterator[Tuple[List[str], List[str]]]: - """Yields source and target.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - source = self._get_cell(row, self.source_col, self.source_sep) - target = self._get_cell(row, self.target_col, self.target_sep) - yield source, target - - def source_features_target_samples( - self, filename: str - ) -> Iterator[Tuple[List[str], List[str], List[str]]]: - """Yields source, features, and target.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - source = self._get_cell(row, self.source_col, self.source_sep) - # Avoids overlap with source. - features = [ - f"[{feature}]" - for feature in self._get_cell( - row, self.features_col, self.features_sep - ) - ] - target = self._get_cell(row, self.target_col, self.target_sep) - yield source, features, target - - def source_features_samples( - self, filename: str - ) -> Iterator[Tuple[List[str], List[str]]]: - """Yields source, and features.""" - with open(filename, "r") as source: - tsv_reader = csv.reader(source, delimiter="\t") - for row in tsv_reader: - source = self._get_cell(row, self.source_col, self.source_sep) - # Avoids overlap with source. - features = [ - f"[{feature}]" - for feature in self._get_cell( - row, self.features_col, self.features_sep - ) - ] - yield source, features - - def samples(self, filename: str) -> Iterator[Tuple[List[str], ...]]: - """Picks the right one for this config.""" - if self.has_features: - return ( - self.source_features_target_samples(filename) - if self.has_target - else self.source_features_samples(filename) - ) - else: - return ( - self.source_target_samples(filename) - if self.has_target - else self.source_samples(filename) - ) - - @staticmethod - def add_argparse_args(parser: argparse.ArgumentParser) -> None: - """Adds data configuration options to the argument parser. - - Args: - parser (argparse.ArgumentParser). - """ - parser.add_argument( - "--source_col", - type=int, - default=defaults.SOURCE_COL, - help="1-based index for source column. Default: %(default)s.", - ) - parser.add_argument( - "--target_col", - type=int, - default=defaults.TARGET_COL, - help="1-based index for target column. Default: %(default)s.", - ) - parser.add_argument( - "--features_col", - type=int, - default=defaults.FEATURES_COL, - help="1-based index for features column; " - "0 indicates the model will not use features. " - "Default: %(default)s.", - ) - parser.add_argument( - "--source_sep", - type=str, - default=defaults.SOURCE_SEP, - help="String used to split source string into symbols; " - "an empty string indicates that each Unicode codepoint " - "is its own symbol. Default: %(default)r.", - ) - parser.add_argument( - "--target_sep", - type=str, - default=defaults.TARGET_SEP, - help="String used to split target string into symbols; " - "an empty string indicates that each Unicode codepoint " - "is its own symbol. Default: %(default)r.", - ) - parser.add_argument( - "--features_sep", - type=str, - default=defaults.FEATURES_SEP, - help="String used to split features string into symbols; " - "an empty string indicates that each Unicode codepoint " - "is its own symbol. Default: %(default)r.", - ) - parser.add_argument( - "--tied_vocabulary", - action="store_true", - default=defaults.TIED_VOCABULARY, - help="Share source and target embeddings. Default: %(default)s.", - ) - parser.add_argument( - "--no_tied_vocabulary", - action="store_false", - dest="tied_vocabulary", - default=True, - ) diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 67c0e9cd..452d2e9c 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -6,7 +6,7 @@ import pytorch_lightning as pl from torch.utils import data as torch_data -from . import dataconfig, data, defaults, models, util +from . import data, defaults, models, util def get_trainer_from_argparse_args( @@ -161,8 +161,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: help="Batch size. Default: %(default)s.", ) # TODO: add --beam_width. - # Data configuration arguments. - dataconfig.DataConfig.add_argparse_args(parser) # Data arguments. data.add_argparse_args(parser) # Architecture arguments; the architecture-specific ones are not needed. diff --git a/yoyodyne/train.py b/yoyodyne/train.py index 4c11d4d1..e998b781 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -8,14 +8,7 @@ from pytorch_lightning import callbacks, loggers from torch.utils import data as torch_data -from . import ( - dataconfig, - data, - defaults, - models, - schedulers, - util, -) +from . import data, defaults, models, schedulers, util class Error(Exception): @@ -114,9 +107,15 @@ def get_data_from_argparse_args( Tuple[data.BaseDataset, data.BaseDataset]: the training and development data. """ - config = dataconfig.DataConfig.from_argparse_args(args) - if config.target_col == 0: + tsv_parser = data.TsvParser( + source_col=args.source_col, + features_col=args.features_col, + target_col=args.target_col, + ) + if not tsv_parser.has_target: raise Error("target_col must be specified for training") + + train_set = data.get_dataset(args.train, config) dev_set = data.get_dataset(args.dev, config, train_set.index) util.log_info(f"Source vocabulary: {train_set.index.source_map.pprint()}") @@ -335,8 +334,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: action="store_false", dest="log_wandb", ) - # Data configuration arguments. - dataconfig.DataConfig.add_argparse_args(parser) # Data arguments. data.add_argparse_args(parser) # Architecture arguments. From d1f29455a1ab8077db302f8fe4e35db467e1a9e2 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 15:26:24 -0400 Subject: [PATCH 06/18] Moves all the pieces into the `data` directory. However, I haven't created the data module class yet, so it's a little goofy. --- yoyodyne/data/__init__.py | 13 +++---- yoyodyne/data/collators.py | 7 ++-- yoyodyne/data/datasets.py | 57 +++++++++++++------------------ yoyodyne/data/indexes.py | 2 +- yoyodyne/data/tsv.py | 2 +- yoyodyne/predict.py | 14 ++++++-- yoyodyne/train.py | 70 ++++++++++++++++++++++++++++++++------ 7 files changed, 107 insertions(+), 58 deletions(-) diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index aa588481..e5a8c5ca 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -10,6 +10,7 @@ from .datasets import DatasetFeatures # noqa: F401 from .datasets import get_dataset # noqa: F401 from .indexes import Index # noqa: F401 +from .tsv import TsvParser, StringParser # noqa: F401 def add_argparse_args(parser: argparse.ArgumentParser) -> None: @@ -73,12 +74,12 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: dest="tied_vocabulary", default=True, ) - parser.add_argument( - "--batch_size", - type=int, - default=defaults.BATCH_SIZE, - help="Batch size. Default: %(default)s.", - ) + # parser.add_argument( + # "--batch_size", + # type=int, + # default=defaults.BATCH_SIZE, + # help="Batch size. Default: %(default)s.", + # ) parser.add_argument( "--max_source_length", type=int, diff --git a/yoyodyne/data/collators.py b/yoyodyne/data/collators.py index dbc4de0f..5a8d8883 100644 --- a/yoyodyne/data/collators.py +++ b/yoyodyne/data/collators.py @@ -41,15 +41,14 @@ def __init__( """ self.index = dataset.index self.pad_idx = self.index.pad_idx - self.config = dataset.config - self.has_features = self.config.has_features - self.has_target = self.config.has_target + self.has_features = dataset.has_features + self.has_target = dataset.has_target self.max_source_length = max_source_length self.max_target_length = max_target_length self.features_offset = ( dataset.index.source_vocab_size if self.has_features else 0 ) - self.separate_features = dataset.config.has_features and arch in [ + self.separate_features = dataset.has_features and arch in [ "pointer_generator_lstm", "transducer", ] diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py index 83de6077..62cad37b 100644 --- a/yoyodyne/data/datasets.py +++ b/yoyodyne/data/datasets.py @@ -4,7 +4,7 @@ superclass constructor, and register the tensor as a buffer. This enables the Trainer to move them to the appropriate device.""" -from typing import List, Optional, Set, Union +from typing import List, Optional, Union import torch from torch import nn @@ -49,17 +49,25 @@ def has_target(self): class BaseDataset(data.Dataset): """Base datatset class.""" + samples: List[str] + index: indexes.Index # Usually copied. + string_parser: tsv.StringParser # Ditto. + def __init__(self): super().__init__() + @property + def has_features(self) -> bool: + return self.index.has_features + + @property + def has_target(self) -> bool: + return self.index.has_target + class DatasetNoFeatures(BaseDataset): """Dataset object without feature column.""" - samples: List[str] - index: indexes.Index # Usually copied. - string_parser: tsv.StringParser # Ditto. - def __init__( self, filename, @@ -78,7 +86,7 @@ def __init__( super().__init__() self.samples = list(tsv_parser.samples(filename)) self.string_parser = string_parser - self.index = index if index is not None else self._make_index() + self.index = index def encode( self, @@ -208,22 +216,6 @@ def decode_target( special=special, ) - @property - def has_features(self) -> bool: - return self.index.has_features - - @staticmethod - def read_index(path: str) -> indexes.Index: - """Helper for loading index. - - Args: - path (str). - - Returns: - indexes.IndexNoFeatures. - """ - return indexes.Index.read(path) - def __getitem__(self, idx: int) -> Item: """Retrieves item by index. @@ -233,12 +225,12 @@ def __getitem__(self, idx: int) -> Item: Returns: Item. """ - if self.config.has_target: + if self.has_target: source, target = self.samples[idx] else: source = self.samples[idx] source_encoded = self.encode(self.index.source_map, source) - if self.config.has_target: + if self.has_target: target_encoded = self.encode( self.index.target_map, target, add_start_tag=False ) @@ -266,7 +258,7 @@ def __getitem__(self, idx: int) -> Item: Returns: Item. """ - if self.config.has_target: + if self.has_target: source, features, target = self.samples[idx] else: source, features = self.samples[idx] @@ -277,7 +269,7 @@ def __getitem__(self, idx: int) -> Item: add_start_tag=False, add_end_tag=False, ) - if self.config.has_target: + if self.has_target: return Item( source_encoded, target=self.encode( @@ -316,6 +308,7 @@ def decode_features( def get_dataset( filename: str, + tsv_parser: tsv.TsvParser, string_parser: tsv.StringParser, index: Union[indexes.Index, str, None] = None, ) -> data.Dataset: @@ -323,14 +316,12 @@ def get_dataset( Args: filename (str): input filename. - string_parser (tsv.StringParser): string parser. - index (Union[index.Index, str], optional): input index file, - or path to index.pkl file. + tsv_parser (tsv.TsvParser). + string_parser (tsv.StringParser). + index (indexes.Index). Returns: data.Dataset: the dataset. """ - cls = DatasetFeatures if string_parser.has_features else DatasetNoFeatures - if isinstance(index, str): - index = indexes.Index.read(index) - return cls(filename, string_parser, index) + cls = DatasetFeatures if tsv_parser.has_features else DatasetNoFeatures + return cls(filename, tsv_parser, string_parser, index) diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py index 22ccedb5..b9e06d9a 100644 --- a/yoyodyne/data/indexes.py +++ b/yoyodyne/data/indexes.py @@ -4,7 +4,7 @@ import pickle from typing import Dict, List, Optional, Set -from . import special +from .. import special class SymbolMap: diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py index 8a448d79..07731667 100644 --- a/yoyodyne/data/tsv.py +++ b/yoyodyne/data/tsv.py @@ -144,7 +144,7 @@ class StringParser: @staticmethod def _get_symbols(string: str, sep: str) -> List[str]: - return list(string) if not sep else sep.split(cell) + return list(string) if not sep else sep.split(string) def source_symbols(self, string: str) -> List[str]: return self._get_symbols(string, self.features_sep) diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 452d2e9c..58e1a33d 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -34,8 +34,18 @@ def get_dataset_from_argparse_args( Returns: data.BaseDataset. """ - config = dataconfig.DataConfig.from_argparse_args(args) - return data.get_dataset(args.predict, config, args.index) + tsv_parser = data.TsvParser( + source_col=args.source_col, + features_col=args.features_col, + target_col=args.target_col, + ) + string_parser = data.StringParser( + source_sep=args.source_sep, + features_sep=args.features_sep, + target_sep=args.target_sep, + ) + index = data.Index.read(args.index) + return data.get_dataset(args.predict, tsv_parser, string_parser, index) def get_loader( diff --git a/yoyodyne/train.py b/yoyodyne/train.py index e998b781..ecb0d5d3 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -1,7 +1,7 @@ """Trains a sequence-to-sequence neural network.""" import argparse -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import pytorch_lightning as pl import wandb @@ -114,16 +114,64 @@ def get_data_from_argparse_args( ) if not tsv_parser.has_target: raise Error("target_col must be specified for training") - - - train_set = data.get_dataset(args.train, config) - dev_set = data.get_dataset(args.dev, config, train_set.index) - util.log_info(f"Source vocabulary: {train_set.index.source_map.pprint()}") - if train_set.has_features: - util.log_info( - f"Feature vocabulary: {train_set.index.features_map.pprint()}" - ) - util.log_info(f"Target vocabulary: {train_set.index.target_map.pprint()}") + string_parser = data.StringParser( + source_sep=args.source_sep, + features_sep=args.features_sep, + target_sep=args.target_sep, + ) + # TODO: move this into the data module. + separate_features = tsv_parser.has_features and args.arch in [ + "pointer_generator_lstm", + "transducer", + ] + # Computes index. + source_vocabulary: Set[str] = set() + features_vocabulary: Set[str] = set() + target_vocabulary: Set[str] = set() + for path in [args.train, args.dev]: + if tsv_parser.has_features: + if tsv_parser.has_target: + for source, features, target in tsv_parser.samples(path): + source_vocabulary.update( + string_parser.source_symbols(source) + ) + features_vocabulary.update( + string_parser.features_symbols(features) + ) + target_vocabulary.update( + string_parser.target_symbols(target) + ) + else: + for source, features in tsv_parser.samples(path): + source_vocabulary.update( + string_parser.source_symbols(source) + ) + features_vocabulary.update( + string_parser.features_symbols(features) + ) + elif tsv_parser.has_target: + for source, target in tsv_parser.samples(path): + source_vocabulary.update(string_parser.source_symbols(source)) + target_vocabulary.update(string_parser.target_symbols(target)) + else: + for source in tsv_parser.samples(path): + source_vocabulary.update(string_parser.source_symbols(source)) + if tsv_parser.has_target and args.tied_vocabulary: + source_vocabulary.update(target_vocabulary) + target_vocabulary.update(source_vocabulary) + index = data.Index( + source_vocabulary=sorted(source_vocabulary), + features_vocabulary=sorted(features_vocabulary) + if separate_features + else None, + target_vocabulary=sorted(target_vocabulary), + ) + util.log_info(f"Source vocabulary: {index.source_map.pprint()}") + if tsv_parser.has_features: + util.log_info(f"Feature vocabulary: {index.features_map.pprint()}") + util.log_info(f"Target vocabulary: {index.target_map.pprint()}") + train_set = data.get_dataset(args.train, tsv_parser, string_parser, index) + dev_set = data.get_dataset(args.dev, tsv_parser, string_parser, index) return train_set, dev_set From 533db765db946c229a035351621b6088948bf706 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 16:40:07 -0400 Subject: [PATCH 07/18] More hacking. --- README.md | 2 +- yoyodyne/data/__init__.py | 24 +-- yoyodyne/data/collators.py | 37 +--- yoyodyne/data/datamodules.py | 205 ++++++++++++++++++++++ yoyodyne/data/datasets.py | 320 +++++++++++++---------------------- yoyodyne/data/tsv.py | 4 - yoyodyne/models/__init__.py | 2 +- yoyodyne/predict.py | 99 +++-------- yoyodyne/train.py | 184 +++++--------------- 9 files changed, 400 insertions(+), 477 deletions(-) create mode 100644 yoyodyne/data/datamodules.py diff --git a/README.md b/README.md index 902460e6..a05fa1de 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ checkpoint, the **full path to the checkpoint** should be specified with for the the provided model checkpoint. During training, we save the best `--save_top_k` checkpoints (by default, 1) -ranked according to accuracy on the `--dev` set. For example, `--save_top_k 5` +ranked according to accuracy on the `--val` set. For example, `--save_top_k 5` will save the top 5 most accurate models. ## Reserved symbols diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index e5a8c5ca..e1b4368d 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -1,16 +1,10 @@ +"""Data classes.""" + import argparse from .. import defaults - +from .datamodules import DataModule # noqa: F401 from .batches import PaddedBatch, PaddedTensor # noqa: F401 -from .collators import Collator # noqa: F401 -from .datasets import Item # noqa: F401 -from .datasets import BaseDataset # noqa: F401 -from .datasets import DatasetNoFeatures # noqa: F401 -from .datasets import DatasetFeatures # noqa: F401 -from .datasets import get_dataset # noqa: F401 -from .indexes import Index # noqa: F401 -from .tsv import TsvParser, StringParser # noqa: F401 def add_argparse_args(parser: argparse.ArgumentParser) -> None: @@ -74,12 +68,12 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: dest="tied_vocabulary", default=True, ) - # parser.add_argument( - # "--batch_size", - # type=int, - # default=defaults.BATCH_SIZE, - # help="Batch size. Default: %(default)s.", - # ) + parser.add_argument( + "--batch_size", + type=int, + default=defaults.BATCH_SIZE, + help="Batch size. Default: %(default)s.", + ) parser.add_argument( "--max_source_length", type=int, diff --git a/yoyodyne/data/collators.py b/yoyodyne/data/collators.py index 5a8d8883..217b88f7 100644 --- a/yoyodyne/data/collators.py +++ b/yoyodyne/data/collators.py @@ -1,6 +1,7 @@ """Collators and related utilities.""" import argparse +import dataclasses from typing import List import torch @@ -13,45 +14,17 @@ class LengthError(Exception): pass +@dataclasses.dataclass class Collator: """Pads data.""" pad_idx: int - features_offset: int has_features: bool has_target: bool - max_source_length: int - max_target_length: int separate_features: bool - - def __init__( - self, - dataset: datasets.BaseDataset, - arch: str, - max_source_length: int = defaults.MAX_SOURCE_LENGTH, - max_target_length: int = defaults.MAX_TARGET_LENGTH, - ): - """Initializes the collator. - - Args: - dataset (dataset.BaseDataset). - arch (str). - max_source_length (int). - max_target_length (int). - """ - self.index = dataset.index - self.pad_idx = self.index.pad_idx - self.has_features = dataset.has_features - self.has_target = dataset.has_target - self.max_source_length = max_source_length - self.max_target_length = max_target_length - self.features_offset = ( - dataset.index.source_vocab_size if self.has_features else 0 - ) - self.separate_features = dataset.has_features and arch in [ - "pointer_generator_lstm", - "transducer", - ] + features_offset: int + max_source_length: int = defaults.MAX_SOURCE_LENGTH + max_target_length: int = defaults.MAX_TARGET_LENGTH def _source_length_error(self, padded_length: int) -> None: """Callback function to raise the error when the padded length of the diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py new file mode 100644 index 00000000..d8c63dc6 --- /dev/null +++ b/yoyodyne/data/datamodules.py @@ -0,0 +1,205 @@ +"""Data modules.""" + +from typing import Optional, Set + +import pytorch_lightning as pl +from torch.utils import data + +from .. import defaults, util +from . import collators, datasets, indexes, tsv + + +class DataModule(pl.LightningDataModule): + """Parses, indexes, collates and loads data.""" + + tsv_parser: tsv.TsvParser + string_parser: tsv.StringParser + index: indexes.Index + batch_size: int + collator: collators.Collator + + def __init__( + self, + # Paths. + *, + train: Optional[str] = None, + val: Optional[str] = None, + predict: Optional[str] = None, + test: Optional[str] = None, + index_path: Optional[str] = None, + # TSV parsing arguments. + source_col: int = defaults.SOURCE_COL, + features_col: int = defaults.FEATURES_COL, + target_col: int = defaults.TARGET_COL, + # String parsing arguments. + source_sep: str = defaults.SOURCE_SEP, + features_sep: str = defaults.FEATURES_SEP, + target_sep: str = defaults.TARGET_SEP, + # Vocabulary options. + tied_vocabulary: bool = defaults.TIED_VOCABULARY, + # Collator options. + batch_size=defaults.BATCH_SIZE, + separate_features: bool = False, + max_source_length: int = defaults.MAX_SOURCE_LENGTH, + max_target_length: int = defaults.MAX_TARGET_LENGTH, + ): + super().__init__() + self.tsv_parser = tsv.TsvParser(source_col, features_col, target_col) + self.string_parser = tsv.StringParser( + source_sep, features_sep, target_sep + ) + self.train = train + self.val = val + self.predict = predict + self.test = test + # Computes index. + source_vocabulary: Set[str] = set() + features_vocabulary: Set[str] = set() + target_vocabulary: Set[str] = set() + for path in [self.train, self.val, self.predict, self.test]: + if path is None: + continue + if self.tsv_parser.has_features: + if self.tsv_parser.has_target: + for source, features, target in self.tsv_parser.samples( + path + ): + source_vocabulary.update( + self.string_parser.source_symbols(source) + ) + features_vocabulary.update( + self.string_parser.features_symbols(features) + ) + target_vocabulary.update( + self.string_parser.target_symbols(target) + ) + else: + for source, features in self.tsv_parser.samples(path): + source_vocabulary.update( + self.string_parser.source_symbols(source) + ) + features_vocabulary.update( + self.string_parser.features_symbols(features) + ) + elif self.tsv_parser.has_target: + for source, target in self.tsv_parser.samples(path): + source_vocabulary.update( + self.string_parser.source_symbols(source) + ) + target_vocabulary.update( + self.string_parser.target_symbols(target) + ) + else: + for source in self.tsv_parser.samples(path): + source_vocabulary.update( + self.string_parser.source_symbols(source) + ) + if self.tsv_parser.has_target and tied_vocabulary: + source_vocabulary.update(target_vocabulary) + target_vocabulary.update(source_vocabulary) + self.separate_features = separate_features + self.index = indexes.Index( + source_vocabulary=sorted(source_vocabulary), + # These two are stored as nulls if empty. + features_vocabulary=sorted(features_vocabulary) + if self.separate_features + else None, + target_vocabulary=sorted(target_vocabulary), + ) + # Stores batch size. + self.batch_size = batch_size + # Makes collator. + self.collator = collators.Collator( + pad_idx=self.index.pad_idx, + has_features=self.index.has_features, + has_target=self.index.has_target, + separate_features=separate_features, + features_offset=self.index.source_vocab_size + if self.index.has_features + else 0, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + + # Helpers. + + def log_vocabularies(self) -> None: + """Logs this module's vocabularies.""" + util.log_info(f"Source vocabulary: {self.index.source_map.pprint()}") + if self.index.has_features: + util.log_info( + f"Features vocabulary: {self.index.features_map.pprint()}" + ) + if self.index.has_target: + util.log_info( + f"Target vocabulary: {self.index.target_map.pprint()}" + ) + + def write_index(self, model_dir: str, experiment: str) -> None: + """Writes the index.""" + index_path = self.index.index_path(model_dir, experiment) + self.index.write(index_path) + util.log_info(f"Index path: {index_path}") + + @property + def has_features(self) -> int: + return self.index.has_features + + @property + def has_target(self) -> int: + return self.index.has_target + + @property + def source_vocab_size(self) -> int: + if self.separate_features: + return self.index.source_vocab_size + else: + return ( + self.index.source_vocab_size + self.index.features_vocab_size + ) + + def _dataset(self, path: str) -> datasets.Dataset: + return datasets.Dataset( + list(self.tsv_parser.samples(path)), + self.index, + self.string_parser, + ) + + # Required API. + + def train_dataloader(self) -> data.DataLoader: + assert self.train is not None, "no train path" + return data.DataLoader( + self._dataset(self.train), + collate_fn=self.collator, + batch_size=self.batch_size, + shuffle=True, + num_workers=1, + ) + + def val_dataloader(self) -> data.DataLoader: + assert self.val is not None, "no val path" + return data.DataLoader( + self._dataset(self.val), + collate_fn=self.collator, + batch_size=2 * self.batch_size, # Because no gradients. + num_workers=1, + ) + + def predict_dataloader(self) -> data.DataLoader: + assert self.predict is not None, "no predict path" + return data.DataLoader( + self._dataset(self.predict), + collate_fn=self.collator, + batch_size=2 * self.batch_size, # Because no gradients. + num_workers=1, + ) + + def test_dataloader(self) -> data.DataLoader: + assert self.test is not None, "no test path" + return data.DataLoader( + self._dataset(self.test), + collate_fn=self.collator, + batch_size=2 * self.batch_size, # Because no gradients. + num_workers=1, + ) diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py index 62cad37b..63324ffa 100644 --- a/yoyodyne/data/datasets.py +++ b/yoyodyne/data/datasets.py @@ -4,7 +4,9 @@ superclass constructor, and register the tensor as a buffer. This enables the Trainer to move them to the appropriate device.""" -from typing import List, Optional, Union +import dataclasses + +from typing import Iterator, List, Optional import torch from torch import nn @@ -46,16 +48,14 @@ def has_target(self): return self.target is not None -class BaseDataset(data.Dataset): - """Base datatset class.""" +@dataclasses.dataclass +class Dataset(data.Dataset): + """Datatset class.""" samples: List[str] - index: indexes.Index # Usually copied. + index: indexes.Index # Usually copied from the DataModule. string_parser: tsv.StringParser # Ditto. - def __init__(self): - super().__init__() - @property def has_features(self) -> bool: return self.index.has_features @@ -64,191 +64,143 @@ def has_features(self) -> bool: def has_target(self) -> bool: return self.index.has_target - -class DatasetNoFeatures(BaseDataset): - """Dataset object without feature column.""" - - def __init__( - self, - filename, - tsv_parser, - string_parser, - index: indexes.Index, - ): - """Initializes the dataset. - - Args: - filename (str): input filename. - string_parser (tsv.StringParser). - other (indexes.Index, optional): if provided, - use this index to avoid recomputing it. - """ - super().__init__() - self.samples = list(tsv_parser.samples(filename)) - self.string_parser = string_parser - self.index = index - - def encode( + def _encode( self, + symbols: List[str], symbol_map: indexes.SymbolMap, - word: List[str], - add_start_tag: bool = True, - add_end_tag: bool = True, ) -> torch.Tensor: - """Encodes a sequence as a tensor of indices with word boundary IDs. + """Encodes a sequence as a tensor of indices with string boundary IDs. Args: - symbol_map (indexes.SymbolMap). - word (List[str]): word to be encoded. - add_start_tag (bool, optional): whether the sequence should be - prepended with a start tag. - add_end_tag (bool, optional): whether the sequence should be - prepended with a end tag. + string (str): string to be encoded. + sep (str): separator to use. + symbol_map (indexes.SymbolMap): symbol map to encode with. Returns: torch.Tensor: the encoded tensor. """ - sequence = [] - if add_start_tag: - sequence.append(special.START) - sequence.extend(word) - if add_end_tag: - sequence.append(special.END) return torch.tensor( [ symbol_map.index(symbol, self.index.unk_idx) - for symbol in sequence + for symbol in symbols ], dtype=torch.long, ) + def encode_source(self, string: str) -> torch.Tensor: + """Encodes a source string, padding with start and end tags. + + Args: + string (str). + + Returns: + torch.Tensor. + """ + wrapped = [special.START] + wrapped.extend(self.string_parser.source_symbols(string)) + wrapped.append(special.END) + return self._encode(wrapped, self.index.source_map) + + def encode_features(self, string: str) -> torch.Tensor: + """Encodes a features string. + + Args: + string (str). + + Returns: + torch.Tensor. + """ + return self._encode( + self.string_parser.features_symbols(string), + self.index.features_map, + ) + + def encode_target(self, string: str) -> torch.Tensor: + """Encodes a features string, padding with end tags. + + Args: + string (str). + + Returns: + torch.Tensor. + """ + wrapped = self.string_parser.target_symbols(string) + wrapped.append(special.END) + return self._encode(wrapped, self.index.target_map) + + # Decoding. + def _decode( self, - symbol_map: indexes.SymbolMap, indices: torch.Tensor, - symbols: bool, - special: bool, - ) -> List[List[str]]: - """Decodes the tensor of indices into symbols. + symbol_map: indexes.SymbolMap, + ) -> Iterator[List[str]]: + """Decodes the tensor of indices into lists of symbols. Args: - symbol_map (indexes.SymbolMap). indices (torch.Tensor): 2d tensor of indices. - symbols (bool): whether to include the regular symbols when - decoding the string. - special (bool): whether to include the special symbols when - decoding the string. + symbol_map (indexes.SymbolMap). - Returns: - List[List[str]]: decoded symbols. + Yields: + List[str]: Decoded symbols. """ - - def include(c: int) -> bool: - """Whether to include the symbol when decoding. - - Args: - c (int): a single symbol index. - - Returns: - bool: whether to include the symbol. - """ - include = False - is_special_char = c in self.index.special_idx - if special: - include |= is_special_char - if symbols: - # Symbols will be anything that is not SPECIAL. - include |= not is_special_char - return include - - decoded = [] - for index in indices.cpu().numpy(): - decoded.append([symbol_map.symbol(c) for c in index if include(c)]) - return decoded + for idx in indices.cpu().numpy(): + yield [ + symbol_map.symbol(c) + for c in idx + if c not in self.index.special_idx + ] def decode_source( self, indices: torch.Tensor, - symbols: bool = True, - special: bool = True, - ) -> List[List[str]]: - """Given a tensor of source indices, returns lists of symbols. + ) -> Iterator[str]: + """Decodes a source tensor. Args: indices (torch.Tensor): 2d tensor of indices. - symbols (bool, optional): whether to include the regular symbols - vocabulary when decoding the string. - special (bool, optional): whether to include the special symbols - when decoding the string. - Returns: - List[List[str]]: decoded symbols. + Yields: + str: Decoded source strings. """ - return self._decode( - self.index.source_map, - indices, - symbols=symbols, - special=special, - ) + for symbols in self._decode(indices, self.index.source_map): + yield self.string_parser.source_string(symbols) - def decode_target( + def decode_features( self, indices: torch.Tensor, - symbols: bool = True, - special: bool = True, - ) -> List[List[str]]: - """Given a tensor of target indices, returns lists of symbols. + ) -> Iterator[str]: + """Decodes a features tensor. Args: indices (torch.Tensor): 2d tensor of indices. - special (bool, optional): whether to include the regular symbol - vocabulary when decoding the string. - special (bool, optional): whether to include the special symbols - when decoding the string. - Returns: - List[List[str]]: decoded symbols. + Yields: + str: Decoded features strings. """ - return self._decode( - self.index.target_map, - indices, - symbols=symbols, - special=special, - ) + for symbols in self._decode(indices, self.index.target_map): + yield self.string_parser.feature_string(symbols) - def __getitem__(self, idx: int) -> Item: - """Retrieves item by index. + def decode_target( + self, + indices: torch.Tensor, + ) -> Iterator[str]: + """Decodes a target tensor. Args: - idx (int). + indices (torch.Tensor): 2d tensor of indices. - Returns: - Item. + Yields: + str: Decoded target strings. """ - if self.has_target: - source, target = self.samples[idx] - else: - source = self.samples[idx] - source_encoded = self.encode(self.index.source_map, source) - if self.has_target: - target_encoded = self.encode( - self.index.target_map, target, add_start_tag=False - ) - return Item(source_encoded, target=target_encoded) - else: - return Item(source_encoded) + for symbols in self._decode(indices, self.index.target_map): + yield self.string_parser.target_string(symbols) + + # Required API. def __len__(self) -> int: return len(self.samples) - -class DatasetFeatures(DatasetNoFeatures): - """Dataset object with feature column.""" - - samples: List[str] - index: indexes.Index # Usually copied. - string_parser: tsv.StringParser # Ditto. - def __getitem__(self, idx: int) -> Item: """Retrieves item by index. @@ -258,70 +210,26 @@ def __getitem__(self, idx: int) -> Item: Returns: Item. """ - if self.has_target: - source, features, target = self.samples[idx] - else: - source, features = self.samples[idx] - source_encoded = self.encode(self.index.source_map, source) - features_encoded = self.encode( - self.index.features_map, - features, - add_start_tag=False, - add_end_tag=False, - ) - if self.has_target: + if self.index.has_features: + if self.index.has_target: + source, features, target = self.samples[idx] + return Item( + source=self.encode_source(source), + features=self.encode_features(features), + target=self.encode_target(target), + ) + else: + source, features = self.samples[idx] + return Item( + source=self.encode_source(source), + features=self.encode_features(features), + ) + elif self.index.has_target: + source, target = self.samples[idx] return Item( - source_encoded, - target=self.encode( - self.index.target_map, target, add_start_tag=False - ), - features=features_encoded, + source=self.encode_source(source), + target=self.encode_target(target), ) else: - return Item(source_encoded, features=features_encoded) - - def decode_features( - self, - indices: torch.Tensor, - symbols: bool = True, - special: bool = True, - ) -> List[List[str]]: - """Given a tensor of feature indices, returns lists of symbols. - - Args: - indices (torch.Tensor): 2d tensor of indices. - symbols (bool, optional): whether to include the regular symbols - when decoding the string. - special (bool, optional): whether to include the special symbols - when decoding the string. - - Returns: - List[List[str]]: decoded symbols. - """ - return self._decode( - self.index.features_map, - indices, - symbols=symbols, - special=special, - ) - - -def get_dataset( - filename: str, - tsv_parser: tsv.TsvParser, - string_parser: tsv.StringParser, - index: Union[indexes.Index, str, None] = None, -) -> data.Dataset: - """Dataset factory. - - Args: - filename (str): input filename. - tsv_parser (tsv.TsvParser). - string_parser (tsv.StringParser). - index (indexes.Index). - - Returns: - data.Dataset: the dataset. - """ - cls = DatasetFeatures if tsv_parser.has_features else DatasetNoFeatures - return cls(filename, tsv_parser, string_parser, index) + source = self.samples[idx] + return Item(source=self.encode_source(source)) diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py index 07731667..d87adf1e 100644 --- a/yoyodyne/data/tsv.py +++ b/yoyodyne/data/tsv.py @@ -66,10 +66,6 @@ def _get_string(row: List[str], col: int) -> str: """ return row[col - 1] # -1 because we're using one-based indexing. - @property - def has_source(self) -> bool: - return True - @property def has_features(self) -> bool: return self.features_col != 0 diff --git a/yoyodyne/models/__init__.py b/yoyodyne/models/__init__.py index 9a645398..201dfdf1 100644 --- a/yoyodyne/models/__init__.py +++ b/yoyodyne/models/__init__.py @@ -49,7 +49,7 @@ def get_model_cls_from_argparse_args( Returns: BaseEncoderDecoder. """ - return get_model_cls(args.arch, args.features_col != 0) + return get_model_cls(args.arch) def add_argparse_args(parser: argparse.ArgumentParser) -> None: diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 58e1a33d..6d9113fb 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -4,9 +4,8 @@ import os import pytorch_lightning as pl -from torch.utils import data as torch_data -from . import data, defaults, models, util +from . import data, models, util def get_trainer_from_argparse_args( @@ -23,61 +22,35 @@ def get_trainer_from_argparse_args( return pl.Trainer.from_argparse_args(args, max_epochs=0) -def get_dataset_from_argparse_args( +def get_datamodule_from_argparse_args( args: argparse.Namespace, -) -> data.BaseDataset: +) -> data.DataModule: """Creates the dataset from CLI arguments. Args: args (argparse.Namespace). Returns: - data.BaseDataset. + data.DataModule. """ - tsv_parser = data.TsvParser( + separate_features = args.features_col != 0 and args.arch in [ + "pointer_generator_lstm", + "transducer", + ] + # TODO(kbg): reuse index? + return data.DataModule( + predict=args.predict, + batch_size=args.batch_size, source_col=args.source_col, features_col=args.features_col, target_col=args.target_col, - ) - string_parser = data.StringParser( source_sep=args.source_sep, features_sep=args.features_sep, target_sep=args.target_sep, - ) - index = data.Index.read(args.index) - return data.get_dataset(args.predict, tsv_parser, string_parser, index) - - -def get_loader( - dataset: data.BaseDataset, - arch: str, - batch_size: int, - max_source_length: int, - max_target_length: int, -) -> torch_data.DataLoader: - """Creates the loader. - - Args: - dataset (data.Dataset). - arch (str). - batch_size (int). - max_source_length (int). - max_target_length (int). - - Returns: - torch_data.DataLoader. - """ - collator = data.Collator( - dataset, - arch, - max_source_length, - max_target_length, - ) - return torch_data.DataLoader( - dataset, - collate_fn=collator, - batch_size=batch_size, - num_workers=1, + tied_vocabulary=args.tied_vocabulary, + separate_features=separate_features, + max_source_length=args.max_source_length, + max_target_length=args.max_target_length, ) @@ -109,8 +82,8 @@ def _mkdir(output: str) -> None: def predict( trainer: pl.Trainer, - model: pl.LightningModule, - loader: torch_data.DataLoader, + model: models.BaseEncoderDecoder, + datamodule: data.DataModule, output: str, ) -> None: """Predicts from the model. @@ -118,25 +91,19 @@ def predict( Args: trainer (pl.Trainer). model (pl.LightningModule). - loader (torch_data.DataLoader). + dataomdule (data.DataModule). output (str). - target_sep (str). """ - dataset = loader.dataset - target_sep = dataset.config.target_sep util.log_info(f"Writing to {output}") _mkdir(output) + decode_target = datamodule.predict_dataloader().dataset.decode_target with open(output, "w") as sink: - for batch in trainer.predict(model, dataloaders=loader): + for batch in trainer.predict(model, datamodule=datamodule): batch = model.evaluator.finalize_predictions( - batch, dataset.index.end_idx, dataset.index.pad_idx + batch, datamodule.index.end_idx, datamodule.index.pad_idx ) - for prediction in dataset.decode_target( - batch, - symbols=True, - special=False, - ): - print(target_sep.join(prediction), file=sink) + for prediction in decode_target(batch): + print(prediction, file=sink) def add_argparse_args(parser: argparse.ArgumentParser) -> None: @@ -163,13 +130,7 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--checkpoint", required=True, help="Path to checkpoint (.ckpt)." ) - # Predicting arguments. - parser.add_argument( - "--batch_size", - type=int, - default=defaults.BATCH_SIZE, - help="Batch size. Default: %(default)s.", - ) + # Prediction arguments. # TODO: add --beam_width. # Data arguments. data.add_argparse_args(parser) @@ -188,15 +149,9 @@ def main() -> None: args = parser.parse_args() util.log_arguments(args) trainer = get_trainer_from_argparse_args(args) - loader = get_loader( - get_dataset_from_argparse_args(args), - args.arch, - args.batch_size, - args.max_source_length, - args.max_target_length, - ) + datamodule = get_datamodule_from_argparse_args(args) model = get_model_from_argparse_args(args) - predict(trainer, model, loader, args.output) + predict(trainer, model, datamodule, args.output) if __name__ == "__main__": diff --git a/yoyodyne/train.py b/yoyodyne/train.py index ecb0d5d3..d5ab7e9c 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -1,7 +1,7 @@ """Trains a sequence-to-sequence neural network.""" import argparse -from typing import List, Optional, Set, Tuple +from typing import List, Optional import pytorch_lightning as pl import wandb @@ -95,133 +95,44 @@ def get_trainer_from_argparse_args( ) -def get_data_from_argparse_args( +def get_datamodule_from_argparse_args( args: argparse.Namespace, -) -> Tuple[data.BaseDataset, data.BaseDataset]: - """Creates the data from CLI arguments. +) -> data.DataModule: + """Creates the datamodule from CLI arguments. Args: - args (argparse.Namespace). + args (Argparse.Namespace). Returns: - Tuple[data.BaseDataset, data.BaseDataset]: the training and - development data. + data.DataModule. """ - tsv_parser = data.TsvParser( + separate_features = args.features_col != 0 and args.arch in [ + "pointer_generator_lstm", + "transducer", + ] + datamodule = data.DataModule( + train=args.train, + val=args.val, + batch_size=args.batch_size, source_col=args.source_col, features_col=args.features_col, target_col=args.target_col, - ) - if not tsv_parser.has_target: - raise Error("target_col must be specified for training") - string_parser = data.StringParser( source_sep=args.source_sep, features_sep=args.features_sep, target_sep=args.target_sep, + tied_vocabulary=args.tied_vocabulary, + separate_features=separate_features, + max_source_length=args.max_source_length, + max_target_length=args.max_target_length, ) - # TODO: move this into the data module. - separate_features = tsv_parser.has_features and args.arch in [ - "pointer_generator_lstm", - "transducer", - ] - # Computes index. - source_vocabulary: Set[str] = set() - features_vocabulary: Set[str] = set() - target_vocabulary: Set[str] = set() - for path in [args.train, args.dev]: - if tsv_parser.has_features: - if tsv_parser.has_target: - for source, features, target in tsv_parser.samples(path): - source_vocabulary.update( - string_parser.source_symbols(source) - ) - features_vocabulary.update( - string_parser.features_symbols(features) - ) - target_vocabulary.update( - string_parser.target_symbols(target) - ) - else: - for source, features in tsv_parser.samples(path): - source_vocabulary.update( - string_parser.source_symbols(source) - ) - features_vocabulary.update( - string_parser.features_symbols(features) - ) - elif tsv_parser.has_target: - for source, target in tsv_parser.samples(path): - source_vocabulary.update(string_parser.source_symbols(source)) - target_vocabulary.update(string_parser.target_symbols(target)) - else: - for source in tsv_parser.samples(path): - source_vocabulary.update(string_parser.source_symbols(source)) - if tsv_parser.has_target and args.tied_vocabulary: - source_vocabulary.update(target_vocabulary) - target_vocabulary.update(source_vocabulary) - index = data.Index( - source_vocabulary=sorted(source_vocabulary), - features_vocabulary=sorted(features_vocabulary) - if separate_features - else None, - target_vocabulary=sorted(target_vocabulary), - ) - util.log_info(f"Source vocabulary: {index.source_map.pprint()}") - if tsv_parser.has_features: - util.log_info(f"Feature vocabulary: {index.features_map.pprint()}") - util.log_info(f"Target vocabulary: {index.target_map.pprint()}") - train_set = data.get_dataset(args.train, tsv_parser, string_parser, index) - dev_set = data.get_dataset(args.dev, tsv_parser, string_parser, index) - return train_set, dev_set - - -def get_loaders( - train_set: data.BaseDataset, - dev_set: data.BaseDataset, - arch: str, - batch_size: int, - max_source_length: int, - max_target_length: int, -) -> Tuple[torch_data.DataLoader, torch_data.DataLoader]: - """Creates the loaders. - - Args: - train_set (data.BaseDataset). - dev_set (data.BaseDataset). - arch (str). - batch_size (int). - max_source_length (int). - max_target_length (int). - - Returns: - Tuple[data.DataLoader, data.DataLoader]: the training and development - loaders. - """ - collator = data.Collator( - train_set, - arch, - max_source_length, - max_target_length, - ) - train_loader = torch_data.DataLoader( - train_set, - collate_fn=collator, - batch_size=batch_size, - shuffle=True, - num_workers=1, # Our data loading is simple. - ) - dev_loader = torch_data.DataLoader( - dev_set, - collate_fn=collator, - batch_size=2 * batch_size, # Because we're not collecting gradients. - num_workers=1, - ) - return train_loader, dev_loader + datamodule.write_index(args.model_dir, args.experiment) + datamodule.log_vocabularies() + return datamodule def get_model_from_argparse_args( - train_set: data.BaseDataset, args: argparse.Namespace, + datamodule: data.DataModule, ) -> models.BaseEncoderDecoder: """Creates the model. @@ -238,7 +149,7 @@ def get_model_from_argparse_args( ) expert = ( models.expert.get_expert( - train_set, + datamodule.train_loader().dataset, epochs=args.oracle_em_epochs, oracle_factor=args.oracle_factor, sed_params_path=args.sed_params, @@ -247,7 +158,7 @@ def get_model_from_argparse_args( else None ) scheduler_kwargs = schedulers.get_scheduler_kwargs_from_argparse_args(args) - separate_features = train_set.has_features and args.arch in [ + separate_features = datamodule.has_features and args.arch in [ "pointer_generator_lstm", "transducer", ] @@ -259,12 +170,12 @@ def get_model_from_argparse_args( else None ) features_vocab_size = ( - train_set.index.features_vocab_size if train_set.has_features else 0 + datamodule.index.features_vocab_size if datamodule.has_features else 0 ) source_vocab_size = ( - train_set.index.source_vocab_size + features_vocab_size + datamodule.index.source_vocab_size + features_vocab_size if not separate_features - else train_set.index.source_vocab_size + else datamodule.index.source_vocab_size ) # Please pass all arguments by keyword and keep in lexicographic order. return model_cls( @@ -277,7 +188,7 @@ def get_model_from_argparse_args( dropout=args.dropout, embedding_size=args.embedding_size, encoder_layers=args.encoder_layers, - end_idx=train_set.index.end_idx, + end_idx=datamodule.index.end_idx, expert=expert, features_encoder_cls=features_encoder_cls, features_vocab_size=features_vocab_size, @@ -287,14 +198,14 @@ def get_model_from_argparse_args( max_source_length=args.max_source_length, max_target_length=args.max_target_length, optimizer=args.optimizer, - output_size=train_set.index.target_vocab_size, - pad_idx=train_set.index.pad_idx, + output_size=datamodule.index.target_vocab_size, + pad_idx=datamodule.index.pad_idx, scheduler=args.scheduler, scheduler_kwargs=scheduler_kwargs, source_encoder_cls=source_encoder_cls, source_vocab_size=source_vocab_size, - start_idx=train_set.index.start_idx, - target_vocab_size=train_set.index.target_vocab_size, + start_idx=datamodule.index.start_idx, + target_vocab_size=datamodule.index.target_vocab_size, ) @@ -341,9 +252,9 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: help="Path to input training data TSV.", ) parser.add_argument( - "--dev", + "--val", required=True, - help="Path to input development data TSV.", + help="Path to input validation data TSV.", ) parser.add_argument( "--model_dir", @@ -355,12 +266,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: help="Path to ckpt checkpoint to resume training from.", ) # Other training arguments. - parser.add_argument( - "--batch_size", - type=int, - default=defaults.BATCH_SIZE, - help="Batch size. Default: %(default)s.", - ) parser.add_argument( "--patience", type=int, help="Patience for early stopping." ) @@ -418,32 +323,19 @@ def main() -> None: util.log_arguments(args) pl.seed_everything(args.seed) trainer = get_trainer_from_argparse_args(args) - train_set, dev_set = get_data_from_argparse_args(args) - index = train_set.index.index_path(args.model_dir, args.experiment) - train_set.index.write(index) - util.log_info(f"Index: {index}") - train_loader, dev_loader = get_loaders( - train_set, - dev_set, - args.arch, - args.batch_size, - args.max_source_length, - args.max_target_length, - ) - model = get_model_from_argparse_args(train_set, args) + datamodule = get_datamodule_from_argparse_args(args) + model = get_model_from_argparse_args(args, datamodule) # Tuning options. Batch autoscaling is unsupported; LR tuning logs the # suggested value and then exits. if args.auto_scale_batch_size: raise Error("Batch auto-scaling is not supported") return if args.auto_lr_find: - result = trainer.tuner.lr_find(model, train_loader, dev_loader) + result = trainer.tuner.lr_find(model, datamodule=datamodule) util.log_info(f"Best initial LR: {result.suggestion():.8f}") return # Otherwise, train and log the best checkpoint. - best_checkpoint = train( - trainer, model, train_loader, dev_loader, train_from=args.train_from - ) + best_checkpoint = train(trainer, model, datamodule, args.train_from) util.log_info(f"Best checkpoint: {best_checkpoint}") From a92e9171356f5f4c7927615dde5caf242ca69b7e Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 20:15:44 -0400 Subject: [PATCH 08/18] Finally gets it working again. --- README.md | 32 ++++++++-- yoyodyne/data/datamodules.py | 84 ++++++++++++-------------- yoyodyne/data/datasets.py | 34 +++++------ yoyodyne/data/tsv.py | 113 +++++++++++++++-------------------- yoyodyne/train.py | 3 +- 5 files changed, 130 insertions(+), 136 deletions(-) diff --git a/README.md b/README.md index a05fa1de..9cb088c8 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,33 @@ import yoyodyne ## Usage -See [`yoyodyne-predict --help`](yoyodyne/predict.py) and -[`yoyodyne-train --help`](yoyodyne/train.py). +### Training + +Training is performed by the [`yoyodyne-train`](yoyodyne.train.py%60) script. +One must specify the following required arguments: + +- `--train`: path to TSV file containing training data +- `--val`: path to TSV file containing validation data +- `--experiment`: name of experiment (pick something unique) +- `--model_dir`: path for model metadata and checkpoints output during + training + +The user can also specify as well as various optional training and architectural +arguments. See below or run [`yoyodyne-train --help`](yoyodyne/train.py) for +more information. + +### Prediction + +Prediction is performed by the [`yoyodyne-predict`](yoyodyne.predict.py%60) +script. One must specify the following required arguments: + +- `--predict`: path to TSV file containing data to be predicted +- `--checkpoint`: path to checkpoint +- `--experiment`: name of experiment +- `--index`: path to index +- `--output`: path for predictions + +Run [`yoyodyne-predict --help`](yoyodyne/predict.py) for more information. ## Data format @@ -153,8 +178,7 @@ source encoder using the `--source_encoder` flag: - `"transformer"`: This is a transformer encoder. When using features, the user can also specify a non-default features encoder -using the `--features_encoder` flag (`"linear"`, `"lstm"`, -`"transformer"`). +using the `--features_encoder` flag (`"linear"`, `"lstm"`, `"transformer"`). For all models, the user may also wish to specify: diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py index d8c63dc6..079140a8 100644 --- a/yoyodyne/data/datamodules.py +++ b/yoyodyne/data/datamodules.py @@ -1,6 +1,6 @@ """Data modules.""" -from typing import Optional, Set +from typing import Iterator, Optional, Set import pytorch_lightning as pl from torch.utils import data @@ -12,8 +12,7 @@ class DataModule(pl.LightningDataModule): """Parses, indexes, collates and loads data.""" - tsv_parser: tsv.TsvParser - string_parser: tsv.StringParser + parser: tsv.TsvParser index: indexes.Index batch_size: int collator: collators.Collator @@ -44,9 +43,13 @@ def __init__( max_target_length: int = defaults.MAX_TARGET_LENGTH, ): super().__init__() - self.tsv_parser = tsv.TsvParser(source_col, features_col, target_col) - self.string_parser = tsv.StringParser( - source_sep, features_sep, target_sep + self.parser = tsv.TsvParser( + source_col=source_col, + features_col=features_col, + target_col=target_col, + source_sep=source_sep, + features_sep=features_sep, + target_sep=target_sep, ) self.train = train self.val = val @@ -56,45 +59,25 @@ def __init__( source_vocabulary: Set[str] = set() features_vocabulary: Set[str] = set() target_vocabulary: Set[str] = set() - for path in [self.train, self.val, self.predict, self.test]: - if path is None: - continue - if self.tsv_parser.has_features: - if self.tsv_parser.has_target: - for source, features, target in self.tsv_parser.samples( - path - ): - source_vocabulary.update( - self.string_parser.source_symbols(source) - ) - features_vocabulary.update( - self.string_parser.features_symbols(features) - ) - target_vocabulary.update( - self.string_parser.target_symbols(target) - ) + for path in self.paths: + if self.parser.has_features: + if self.parser.has_target: + for source, features, target in self.parser.samples(path): + source_vocabulary.update(source) + features_vocabulary.update(features) + target_vocabulary.update(target) else: - for source, features in self.tsv_parser.samples(path): - source_vocabulary.update( - self.string_parser.source_symbols(source) - ) - features_vocabulary.update( - self.string_parser.features_symbols(features) - ) - elif self.tsv_parser.has_target: - for source, target in self.tsv_parser.samples(path): - source_vocabulary.update( - self.string_parser.source_symbols(source) - ) - target_vocabulary.update( - self.string_parser.target_symbols(target) - ) + for source, features in self.parser.samples(path): + source_vocabulary.update(source) + features_vocabulary.update(features) + elif self.parser.has_target: + for source, target in self.parser.samples(path): + source_vocabulary.update(source) + target_vocabulary.update(target) else: - for source in self.tsv_parser.samples(path): - source_vocabulary.update( - self.string_parser.source_symbols(source) - ) - if self.tsv_parser.has_target and tied_vocabulary: + for source in self.parser.samples(path): + source_vocabulary.update(source) + if self.parser.has_target and tied_vocabulary: source_vocabulary.update(target_vocabulary) target_vocabulary.update(source_vocabulary) self.separate_features = separate_features @@ -123,6 +106,17 @@ def __init__( # Helpers. + @property + def paths(self) -> Iterator[str]: + if self.train is not None: + yield self.train + if self.val is not None: + yield self.val + if self.predict is not None: + yield self.predict + if self.test is not None: + yield self.test + def log_vocabularies(self) -> None: """Logs this module's vocabularies.""" util.log_info(f"Source vocabulary: {self.index.source_map.pprint()}") @@ -160,9 +154,9 @@ def source_vocab_size(self) -> int: def _dataset(self, path: str) -> datasets.Dataset: return datasets.Dataset( - list(self.tsv_parser.samples(path)), + list(self.parser.samples(path)), self.index, - self.string_parser, + self.parser, ) # Required API. diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py index 63324ffa..0669317f 100644 --- a/yoyodyne/data/datasets.py +++ b/yoyodyne/data/datasets.py @@ -52,9 +52,9 @@ def has_target(self): class Dataset(data.Dataset): """Datatset class.""" - samples: List[str] + samples: List[List[str]] index: indexes.Index # Usually copied from the DataModule. - string_parser: tsv.StringParser # Ditto. + parser: tsv.TsvParser # Ditto. @property def has_features(self) -> bool: @@ -87,44 +87,41 @@ def _encode( dtype=torch.long, ) - def encode_source(self, string: str) -> torch.Tensor: + def encode_source(self, symbols: List[str]) -> torch.Tensor: """Encodes a source string, padding with start and end tags. Args: - string (str). + symbols (List[str]). Returns: torch.Tensor. """ wrapped = [special.START] - wrapped.extend(self.string_parser.source_symbols(string)) + wrapped.extend(symbols) wrapped.append(special.END) return self._encode(wrapped, self.index.source_map) - def encode_features(self, string: str) -> torch.Tensor: + def encode_features(self, symbols: List[str]) -> torch.Tensor: """Encodes a features string. Args: - string (str). + symbols (List[str]). Returns: torch.Tensor. """ - return self._encode( - self.string_parser.features_symbols(string), - self.index.features_map, - ) + return self._encode(symbols, self.index.features_map) - def encode_target(self, string: str) -> torch.Tensor: + def encode_target(self, symbols: List[str]) -> torch.Tensor: """Encodes a features string, padding with end tags. Args: - string (str). + symbols (List[str]). Returns: torch.Tensor. """ - wrapped = self.string_parser.target_symbols(string) + wrapped = symbols wrapped.append(special.END) return self._encode(wrapped, self.index.target_map) @@ -164,7 +161,7 @@ def decode_source( str: Decoded source strings. """ for symbols in self._decode(indices, self.index.source_map): - yield self.string_parser.source_string(symbols) + yield self.parser.source_string(symbols) def decode_features( self, @@ -179,7 +176,7 @@ def decode_features( str: Decoded features strings. """ for symbols in self._decode(indices, self.index.target_map): - yield self.string_parser.feature_string(symbols) + yield self.parser.feature_string(symbols) def decode_target( self, @@ -194,7 +191,7 @@ def decode_target( str: Decoded target strings. """ for symbols in self._decode(indices, self.index.target_map): - yield self.string_parser.target_string(symbols) + yield self.parser.target_string(symbols) # Required API. @@ -231,5 +228,4 @@ def __getitem__(self, idx: int) -> Item: target=self.encode_target(target), ) else: - source = self.samples[idx] - return Item(source=self.encode_source(source)) + return Item(source=self.encode_source(self.samples[idx])) diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py index d87adf1e..60300e66 100644 --- a/yoyodyne/data/tsv.py +++ b/yoyodyne/data/tsv.py @@ -1,15 +1,13 @@ """TSV parsing. -The TsvParser yield string tuples from TSV files using 1-based indexing. - -The CellParser converts between raw strings ("strings") and lists of string -symbols. +The TsvParser yields data from TSV files using 1-based indexing and custom +separators. """ import csv import dataclasses -from typing import Iterator, List, Tuple +from typing import Iterator, List, Tuple, Union from .. import defaults, util @@ -22,7 +20,7 @@ class Error(Exception): @dataclasses.dataclass class TsvParser: - """Streams rows from a TSV file. + """Streams data from a TSV file. Args: source_col (int, optional): 1-indexed column in TSV containing @@ -31,11 +29,23 @@ class TsvParser: features strings. target_col (int, optional): 1-indexed column in TSV containing target strings. + source_sep (str, optional): string used to split source string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + features_sep (str, optional): string used to split features string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. + target_sep (str, optional): string used to split target string into + symbols; an empty string indicates that each Unicode codepoint is + its own symbol. """ source_col: int = defaults.SOURCE_COL features_col: int = defaults.FEATURES_COL target_col: int = defaults.TARGET_COL + source_sep: str = defaults.SOURCE_SEP + features_sep: str = defaults.FEATURES_SEP + target_sep: str = defaults.TARGET_SEP def __post_init__(self) -> None: # This is automatically called after initialization. @@ -74,76 +84,47 @@ def has_features(self) -> bool: def has_target(self) -> bool: return self.target_col != 0 - def source_samples(self, path: str) -> Iterator[str]: - """Yields source.""" - for row in self._tsv_reader(path): - yield self._get_string(row, self.source_col) - - def source_target_samples(self, path: str) -> Iterator[Tuple[str, str]]: - """Yields source and target.""" - for row in self._tsv_reader(path): - source = self._get_string(row, self.source_col) - target = self._get_string(row, self.target_col) - yield source, target - - def source_features_target_samples( + def samples( self, path: str - ) -> Iterator[Tuple[str, str, str]]: - """Yields source, features, and target.""" - for row in self._tsv_reader(path): - source = self._get_string(row, self.source_col) - features = self._get_string(row, self.features_col) - target = self._get_string(row, self.target_col) - yield source, features, target - - def source_features_samples(self, path: str) -> Iterator[Tuple[str, str]]: - """Yields source, and features.""" + ) -> Iterator[ + Union[ + List[str], + Tuple[List[str], List[str]], + Tuple[List[str], List[str], List[str]], + ] + ]: + """Yields source, and features and/or target if available.""" for row in self._tsv_reader(path): - source = self._get_string(row, self.source_col) - features = self._get_string(row, self.features_col) - yield source, features - - def samples(self, path: str) -> Iterator[Tuple[str, ...]]: - """Picks the right one.""" - if self.has_features: - if self.has_target: - self.source_features_target_samples(path) + source = self.source_symbols( + self._get_string(row, self.source_col) + ) + if self.has_features: + features = self.features_symbols( + self._get_string(row, self.features_col) + ) + if self.has_target: + target = self.target_symbols( + self._get_string(row, self.target_col) + ) + yield source, features, target + else: + yield source, features + elif self.has_target: + target = self.target_symbols( + self._get_string(row, self.target_col) + ) + yield source, target else: - return self.source_features_samples(path) - elif self.has_target: - return self.source_target_samples(path) - else: - return self.source_samples(path) - - -@dataclasses.dataclass -class StringParser: - """Parses strings from the TSV file into lists of symbols. - - Args: - source_sep (str, optional): string used to split source string into - symbols; an empty string indicates that each Unicode codepoint is - its own symbol. - features_sep (str, optional): string used to split features string into - symbols; an empty string indicates that each Unicode codepoint is - its own symbol. - target_sep (str, optional): string used to split target string into - symbols; an empty string indicates that each Unicode codepoint is - its own symbol. - """ - - source_sep: str = defaults.SOURCE_SEP - features_sep: str = defaults.FEATURES_SEP - target_sep: str = defaults.TARGET_SEP + yield source - # Parsing methods. + # String parsing methods. @staticmethod def _get_symbols(string: str, sep: str) -> List[str]: return list(string) if not sep else sep.split(string) def source_symbols(self, string: str) -> List[str]: - return self._get_symbols(string, self.features_sep) + return self._get_symbols(string, self.source_sep) def features_symbols(self, string: str) -> List[str]: # We deliberately obfuscate these to avoid overlap with source. diff --git a/yoyodyne/train.py b/yoyodyne/train.py index d5ab7e9c..a41de703 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -35,7 +35,6 @@ def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List: wandb.define_metric("val_accuracy", summary="max") # Logs the path to local artifacts made by PTL. wandb.config.update({"local_run_dir": trainer_logger[0].log_dir}) - return trainer_logger @@ -149,7 +148,7 @@ def get_model_from_argparse_args( ) expert = ( models.expert.get_expert( - datamodule.train_loader().dataset, + datamodule.train_dataloader().dataset, epochs=args.oracle_em_epochs, oracle_factor=args.oracle_factor, sed_params_path=args.sed_params, From c62888221b2d96c1ac738b7b725f4b4d3260b9fa Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 20:24:39 -0400 Subject: [PATCH 09/18] Updates tests to reflect. --- tests/collator_test.py | 47 ------------------------------------------ tests/dataset_test.py | 14 ------------- 2 files changed, 61 deletions(-) delete mode 100644 tests/collator_test.py delete mode 100644 tests/dataset_test.py diff --git a/tests/collator_test.py b/tests/collator_test.py deleted file mode 100644 index 9e002df8..00000000 --- a/tests/collator_test.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest - -from yoyodyne import collators, dataconfig, datasets - - -@pytest.mark.parametrize( - ["arch", "has_features", "has_target", "expected_separate_features"], - [ - ("feature_invariant_transformer", True, True, False), - ("feature_invariant_transformer", True, False, False), - ("lstm", True, True, False), - ("lstm", False, True, False), - ("lstm", True, False, False), - ("lstm", False, False, False), - ("pointer_generator_lstm", True, True, True), - ("pointer_generator_lstm", False, True, False), - ("pointer_generator_lstm", True, False, True), - ("pointer_generator_lstm", False, False, False), - ("transducer", True, True, True), - ("transducer", False, True, False), - ("transducer", True, False, True), - ("transducer", False, False, False), - ("transformer", True, True, False), - ("transformer", False, True, False), - ("transformer", True, False, False), - ("transformer", False, False, False), - ], -) -def test_get_collator( - make_trivial_tsv_file, - arch, - has_features, - has_target, - expected_separate_features, -): - filename = make_trivial_tsv_file - config = dataconfig.DataConfig( - features_col=3 if has_features else 0, - target_col=2 if has_target else 0, - ) - dataset = datasets.get_dataset(filename, config) - collator = collators.Collator( - dataset, - arch, - ) - assert collator.has_target == has_target - assert collator.separate_features == expected_separate_features diff --git a/tests/dataset_test.py b/tests/dataset_test.py deleted file mode 100644 index 21421e36..00000000 --- a/tests/dataset_test.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest - -from yoyodyne import dataconfig, datasets - - -@pytest.mark.parametrize( - "features_col, expected_cls", - [(0, datasets.DatasetNoFeatures), (3, datasets.DatasetFeatures)], -) -def test_get_dataset(make_trivial_tsv_file, features_col, expected_cls): - filename = make_trivial_tsv_file - config = dataconfig.DataConfig(features_col=features_col) - dataset = datasets.get_dataset(filename, config) - assert type(dataset) is expected_cls From 82d485f63337f0b81ac027e4023d04bcc0bfb373 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 20:30:11 -0400 Subject: [PATCH 10/18] Fixes some REAMDE typos. --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 9cb088c8..f0b95e98 100644 --- a/README.md +++ b/README.md @@ -56,8 +56,8 @@ import yoyodyne ### Training -Training is performed by the [`yoyodyne-train`](yoyodyne.train.py%60) script. -One must specify the following required arguments: +Training is performed by the [`yoyodyne-train`](yoyodyne.train.py) script. One +must specify the following required arguments: - `--train`: path to TSV file containing training data - `--val`: path to TSV file containing validation data @@ -98,11 +98,12 @@ the third contains semi-colon delimited feature strings: this format is specified by `--features-col 3`. -Alternatively, for the SIGMORPHON 2016 shared task data format: +Alternatively, for the [SIGMORPHON 2016 shared +task](https://sigmorphon.github.io/sharedtasks/2016/) data: source feat1,feat2,... target -this format is specified by `--features-col 2 --features-sep , --target-col 3`. +this format is specified by `--features_col 2 --features_sep , --target_col 3`. In order to ensure that targets are ignored during prediction, one can specify `--target_col 0`. From c40f1cb25ab14180d5b5003bb3d4522a7f0ba1ff Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 20:33:46 -0400 Subject: [PATCH 11/18] Update batches.py --- yoyodyne/data/batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yoyodyne/data/batches.py b/yoyodyne/data/batches.py index 2e1648b8..dd42a6cf 100644 --- a/yoyodyne/data/batches.py +++ b/yoyodyne/data/batches.py @@ -28,7 +28,7 @@ def __init__( ): """Constructs the padded tensor from a list of tensors. - The optional pad_len argument can be used, e.g., to keep all data + The optional pad_len argument can be used, e.g., to keep all batches the exact same length, which improves performance on certain accelerators. If not specified, it will be computed using the length of the longest input tensor. From ca885c5e9f85747e771e9f85cb588a9a9835554c Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 20:36:51 -0400 Subject: [PATCH 12/18] More typo cleanups. --- pyproject.toml | 2 +- yoyodyne/models/lstm.py | 8 ++++---- yoyodyne/models/pointer_generator.py | 8 ++++---- yoyodyne/models/transformer.py | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7206b1d3..56ad8ebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ exclude = ["examples*"] [project] name = "yoyodyne" -version = "0.2.3" +version = "0.2.4" description = "Small-vocabulary neural sequence-to-sequence models" readme = "README.md" requires-python = ">= 3.9" diff --git a/yoyodyne/models/lstm.py b/yoyodyne/models/lstm.py index 0619a985..e656e106 100644 --- a/yoyodyne/models/lstm.py +++ b/yoyodyne/models/lstm.py @@ -139,10 +139,10 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all data predicted an EOS symbol. - # If we have a target (and are thus computing loss), - # we only break when we have decoded at least the the - # same number of steps as the target length. + # Breaks when all sequences have predicted an EOS symbol. If we + # have a target (and are thus computing loss), we only break + # when we have decoded at least the the same number of steps as + # the target length. if finished.all(): if target is None or decoder_input.size(-1) >= target.size( -1 diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index 18e07124..c6889dff 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -289,10 +289,10 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all data predicted an EOS symbol. - # If we have a target (and are thus computing loss), - # we only break when we have decoded at least the the - # same number of steps as the target length. + # Breaks when all sequences have predicted an EOS symbol.If we + # have a target (and are thus computing loss), we only break + # when we have decoded at least the the same number of steps as + # the target length. if finished.all(): if target is None or decoder_input.size(-1) >= target.size( -1 diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index 5b000e91..91250a29 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -105,10 +105,10 @@ def _decode_greedy( finished = torch.logical_or( finished, (predictions[-1] == self.end_idx) ) - # Breaks when all data predicted an EOS symbol. - # If we have a target (and are thus computing loss), - # we only break when we have decoded at least the the - # same number of steps as the target length. + # Breaks when all sequences have predicted an EOS symbol. If we + # have a target (and are thus computing loss), we only break when + # we have decoded at least the the same number of steps as the + # target length. if finished.all(): if targets is None or len(outputs) >= targets.size(-1): break From b83815b3c257148e36c79974669f561f4173aed6 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sun, 16 Jul 2023 20:38:39 -0400 Subject: [PATCH 13/18] Late-breaking typos. --- yoyodyne/models/pointer_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index c6889dff..96a4fb74 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -289,7 +289,7 @@ def decode( finished = torch.logical_or( finished, (decoder_input == self.end_idx) ) - # Breaks when all sequences have predicted an EOS symbol.If we + # Breaks when all sequences have predicted an EOS symbol. If we # have a target (and are thus computing loss), we only break # when we have decoded at least the the same number of steps as # the target length. From 8fd10cba34f758a56fd24f95252b30798f11990d Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 17 Jul 2023 12:51:39 -0400 Subject: [PATCH 14/18] Fixes some late-breaking typos. --- yoyodyne/data/datamodules.py | 7 ++++--- yoyodyne/data/datasets.py | 2 +- yoyodyne/data/tsv.py | 16 +++++++--------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py index 079140a8..9e1322d0 100644 --- a/yoyodyne/data/datamodules.py +++ b/yoyodyne/data/datamodules.py @@ -83,11 +83,12 @@ def __init__( self.separate_features = separate_features self.index = indexes.Index( source_vocabulary=sorted(source_vocabulary), - # These two are stored as nulls if empty. features_vocabulary=sorted(features_vocabulary) - if self.separate_features + if features_vocabulary + else None, + target_vocabulary=sorted(target_vocabulary) + if target_vocabulary else None, - target_vocabulary=sorted(target_vocabulary), ) # Stores batch size. self.batch_size = batch_size diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py index 0669317f..20326305 100644 --- a/yoyodyne/data/datasets.py +++ b/yoyodyne/data/datasets.py @@ -228,4 +228,4 @@ def __getitem__(self, idx: int) -> Item: target=self.encode_target(target), ) else: - return Item(source=self.encode_source(self.samples[idx])) + return Item(source=self.encode_source(source)) diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py index 60300e66..1f063ceb 100644 --- a/yoyodyne/data/tsv.py +++ b/yoyodyne/data/tsv.py @@ -9,7 +9,7 @@ import dataclasses from typing import Iterator, List, Tuple, Union -from .. import defaults, util +from .. import defaults class Error(Exception): @@ -50,15 +50,13 @@ class TsvParser: def __post_init__(self) -> None: # This is automatically called after initialization. if self.source_col < 1: - raise Error(f"Invalid source column: {self.source_col}") + raise Error(f"Out of range source column: {self.source_col}") if self.features_col < 0: - raise Error(f"Invalid features column: {self.features_col}") - if self.features_col != 0: - util.log_info("Including features") + raise Error(f"Out of range features column: {self.features_col}") + if self.features_col < 0: + raise Error(f"Out of range features column: {self.features_col}") if self.target_col < 0: - raise Error(f"Invalid target column: {self.target_col}") - if self.target_col == 0: - util.log_info("Ignoring targets in input") + raise Error(f"Out of range target column: {self.target_col}") @staticmethod def _tsv_reader(path: str) -> Iterator[str]: @@ -121,7 +119,7 @@ def samples( @staticmethod def _get_symbols(string: str, sep: str) -> List[str]: - return list(string) if not sep else sep.split(string) + return list(string) if not sep else string.split(sep) def source_symbols(self, string: str) -> List[str]: return self._get_symbols(string, self.source_sep) From 5cf25dbe8a174c60bdc86cc8526b77c1a8b9941b Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 17 Jul 2023 14:30:41 -0400 Subject: [PATCH 15/18] Cleanups for comments. --- README.md | 2 +- yoyodyne/data/datamodules.py | 4 +--- yoyodyne/data/indexes.py | 26 ++++++++------------------ yoyodyne/train.py | 2 +- 4 files changed, 11 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index f0b95e98..64fe30fa 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ must specify the following required arguments: - `--model_dir`: path for model metadata and checkpoints output during training -The user can also specify as well as various optional training and architectural +The user can also specify various optional training and architectural arguments. See below or run [`yoyodyne-train --help`](yoyodyne/train.py) for more information. diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py index 9e1322d0..6246eadf 100644 --- a/yoyodyne/data/datamodules.py +++ b/yoyodyne/data/datamodules.py @@ -132,9 +132,7 @@ def log_vocabularies(self) -> None: def write_index(self, model_dir: str, experiment: str) -> None: """Writes the index.""" - index_path = self.index.index_path(model_dir, experiment) - self.index.write(index_path) - util.log_info(f"Index path: {index_path}") + self.index.write(model_dir, experiment) @property def has_features(self) -> int: diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py index b9e06d9a..e702b554 100644 --- a/yoyodyne/data/indexes.py +++ b/yoyodyne/data/indexes.py @@ -4,7 +4,7 @@ import pickle from typing import Dict, List, Optional, Set -from .. import special +from .. import special, util class SymbolMap: @@ -84,22 +84,9 @@ def __init__( # Serialization support. - @staticmethod - def index_path(model_dir: str, experiment: str) -> str: - """Computes the index path. - - Args: - model_dir (str). - experiment (str). - - Returns: - str. - """ - return f"{model_dir}/{experiment}/index.pkl" - @classmethod def read(cls, path: str): - """Loads symbol mappings. + """Loads index. Args: path (str): input path. @@ -111,12 +98,15 @@ def read(cls, path: str): setattr(index, key, value) return index - def write(self, path: str) -> None: - """Saves index. + def write(self, model_dir: str, experiment: str) -> None: + """Writes index. Args: - path (str): output path. + model_dir (str). + experiment (str). """ + path = f"{model_dir}/{experiment}/index.pkl" + util.log_info(f"Index path: {path}") os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as sink: pickle.dump(vars(self), sink) diff --git a/yoyodyne/train.py b/yoyodyne/train.py index a41de703..02ab209c 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -124,7 +124,7 @@ def get_datamodule_from_argparse_args( max_source_length=args.max_source_length, max_target_length=args.max_target_length, ) - datamodule.write_index(args.model_dir, args.experiment) + datamodule.index.write(args.model_dir, args.experiment) datamodule.log_vocabularies() return datamodule From 6b0ffca822afc4aee2c3784553f577e0934d5d3a Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 17 Jul 2023 14:40:01 -0400 Subject: [PATCH 16/18] Addressing most of Adam's comments. --- yoyodyne/data/datasets.py | 3 +-- yoyodyne/data/tsv.py | 1 + yoyodyne/predict.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py index 20326305..267379dc 100644 --- a/yoyodyne/data/datasets.py +++ b/yoyodyne/data/datasets.py @@ -72,8 +72,7 @@ def _encode( """Encodes a sequence as a tensor of indices with string boundary IDs. Args: - string (str): string to be encoded. - sep (str): separator to use. + symbols (List[str]): symbols to be encoded. symbol_map (indexes.SymbolMap): symbol map to encode with. Returns: diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py index 1f063ceb..5087ffda 100644 --- a/yoyodyne/data/tsv.py +++ b/yoyodyne/data/tsv.py @@ -66,6 +66,7 @@ def _tsv_reader(path: str) -> Iterator[str]: @staticmethod def _get_string(row: List[str], col: int) -> str: """Returns a string from a row by index. + Args: row (List[str]): the split row. col (int): the column index. diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 6d9113fb..eaad92c1 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -91,18 +91,18 @@ def predict( Args: trainer (pl.Trainer). model (pl.LightningModule). - dataomdule (data.DataModule). + datamdule (data.DataModule). output (str). """ util.log_info(f"Writing to {output}") _mkdir(output) - decode_target = datamodule.predict_dataloader().dataset.decode_target + loader = datamodule.predict_dataloader() with open(output, "w") as sink: - for batch in trainer.predict(model, datamodule=datamodule): + for batch in trainer.predict(model, loader): batch = model.evaluator.finalize_predictions( batch, datamodule.index.end_idx, datamodule.index.pad_idx ) - for prediction in decode_target(batch): + for prediction in loader.dataset.decode_target(batch): print(prediction, file=sink) From bda40da5956131bedea005aca31dfa5148ee4640 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 17 Jul 2023 16:36:28 -0400 Subject: [PATCH 17/18] Last-minute fixes with indices. --- README.md | 10 +++++----- yoyodyne/data/__init__.py | 1 + yoyodyne/data/datamodules.py | 37 ++++++++++++++++++++---------------- yoyodyne/data/indexes.py | 27 +++++++++++++++++++++----- yoyodyne/predict.py | 17 +++++++++++------ yoyodyne/train.py | 12 ++++++------ 6 files changed, 66 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 64fe30fa..70ef9d0c 100644 --- a/README.md +++ b/README.md @@ -59,11 +59,11 @@ import yoyodyne Training is performed by the [`yoyodyne-train`](yoyodyne.train.py) script. One must specify the following required arguments: -- `--train`: path to TSV file containing training data -- `--val`: path to TSV file containing validation data - `--experiment`: name of experiment (pick something unique) - `--model_dir`: path for model metadata and checkpoints output during training +- `--train`: path to TSV file containing training data +- `--val`: path to TSV file containing validation data The user can also specify various optional training and architectural arguments. See below or run [`yoyodyne-train --help`](yoyodyne/train.py) for @@ -74,10 +74,10 @@ more information. Prediction is performed by the [`yoyodyne-predict`](yoyodyne.predict.py%60) script. One must specify the following required arguments: -- `--predict`: path to TSV file containing data to be predicted -- `--checkpoint`: path to checkpoint +- `--model_dir`: path for model metadata - `--experiment`: name of experiment -- `--index`: path to index +- `--checkpoint`: path to checkpoint +- `--predict`: path to TSV file containing data to be predicted - `--output`: path for predictions Run [`yoyodyne-predict --help`](yoyodyne/predict.py) for more information. diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py index e1b4368d..7553e247 100644 --- a/yoyodyne/data/__init__.py +++ b/yoyodyne/data/__init__.py @@ -5,6 +5,7 @@ from .. import defaults from .datamodules import DataModule # noqa: F401 from .batches import PaddedBatch, PaddedTensor # noqa: F401 +from .indexes import Index # noqa: F401 def add_argparse_args(parser: argparse.ArgumentParser) -> None: diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py index 6246eadf..22ef289a 100644 --- a/yoyodyne/data/datamodules.py +++ b/yoyodyne/data/datamodules.py @@ -41,6 +41,8 @@ def __init__( separate_features: bool = False, max_source_length: int = defaults.MAX_SOURCE_LENGTH, max_target_length: int = defaults.MAX_TARGET_LENGTH, + # Indexing. + index: Optional[indexes.Index] = None, ): super().__init__() self.parser = tsv.TsvParser( @@ -55,6 +57,24 @@ def __init__( self.val = val self.predict = predict self.test = test + self.batch_size = batch_size + self.separate_features = separate_features + self.index = ( + index if index is not None else self._make_index(tied_vocabulary) + ) + self.collator = collators.Collator( + pad_idx=self.index.pad_idx, + has_features=self.index.has_features, + has_target=self.index.has_target, + separate_features=separate_features, + features_offset=self.index.source_vocab_size + if self.index.has_features + else 0, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + + def _make_index(self, tied_vocabulary: bool) -> indexes.Index: # Computes index. source_vocabulary: Set[str] = set() features_vocabulary: Set[str] = set() @@ -80,8 +100,7 @@ def __init__( if self.parser.has_target and tied_vocabulary: source_vocabulary.update(target_vocabulary) target_vocabulary.update(source_vocabulary) - self.separate_features = separate_features - self.index = indexes.Index( + return indexes.Index( source_vocabulary=sorted(source_vocabulary), features_vocabulary=sorted(features_vocabulary) if features_vocabulary @@ -90,20 +109,6 @@ def __init__( if target_vocabulary else None, ) - # Stores batch size. - self.batch_size = batch_size - # Makes collator. - self.collator = collators.Collator( - pad_idx=self.index.pad_idx, - has_features=self.index.has_features, - has_target=self.index.has_target, - separate_features=separate_features, - features_offset=self.index.source_vocab_size - if self.index.has_features - else 0, - max_source_length=max_source_length, - max_target_length=max_target_length, - ) # Helpers. diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py index e702b554..d7c2d6d5 100644 --- a/yoyodyne/data/indexes.py +++ b/yoyodyne/data/indexes.py @@ -4,7 +4,7 @@ import pickle from typing import Dict, List, Optional, Set -from .. import special, util +from .. import special class SymbolMap: @@ -85,19 +85,37 @@ def __init__( # Serialization support. @classmethod - def read(cls, path: str): + def read(cls, model_dir: str, experiment: str) -> "Index": """Loads index. Args: - path (str): input path. + model_dir (str). + experiment (str). + + Returns: + Index. """ index = cls.__new__(cls) + path = index.index_path(model_dir, experiment) with open(path, "rb") as source: dictionary = pickle.load(source) for key, value in dictionary.items(): setattr(index, key, value) return index + @staticmethod + def index_path(model_dir: str, experiment: str) -> str: + """Computes path for the index file. + + Args: + model_dir (str). + experiment (str). + + Returns: + str. + """ + return f"{model_dir}/{experiment}/index.pkl" + def write(self, model_dir: str, experiment: str) -> None: """Writes index. @@ -105,8 +123,7 @@ def write(self, model_dir: str, experiment: str) -> None: model_dir (str). experiment (str). """ - path = f"{model_dir}/{experiment}/index.pkl" - util.log_info(f"Index path: {path}") + path = self.index_path(model_dir, experiment) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as sink: pickle.dump(vars(self), sink) diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index eaad92c1..fa572b62 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -37,7 +37,7 @@ def get_datamodule_from_argparse_args( "pointer_generator_lstm", "transducer", ] - # TODO(kbg): reuse index? + index = data.Index.read(args.model_dir, args.experiment) return data.DataModule( predict=args.predict, batch_size=args.batch_size, @@ -51,6 +51,7 @@ def get_datamodule_from_argparse_args( separate_features=separate_features, max_source_length=args.max_source_length, max_target_length=args.max_target_length, + index=index, ) @@ -112,10 +113,18 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: Args: parser (argparse.ArgumentParser). """ + # Path arguments. + parser.add_argument( + "--checkpoint", required=True, help="Path to checkpoint (.ckpt)." + ) + parser.add_argument( + "--model_dir", + required=True, + help="Path to output model directory.", + ) parser.add_argument( "--experiment", required=True, help="Name of experiment." ) - # Path arguments. parser.add_argument( "--predict", required=True, @@ -126,10 +135,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: required=True, help="Path to prediction output data TSV.", ) - parser.add_argument("--index", required=True, help="Path to index (.pkl).") - parser.add_argument( - "--checkpoint", required=True, help="Path to checkpoint (.ckpt)." - ) # Prediction arguments. # TODO: add --beam_width. # Data arguments. diff --git a/yoyodyne/train.py b/yoyodyne/train.py index 02ab209c..2e5a85e5 100644 --- a/yoyodyne/train.py +++ b/yoyodyne/train.py @@ -241,10 +241,15 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: Args: argparse.ArgumentParser. """ + # Path arguments. + parser.add_argument( + "--model_dir", + required=True, + help="Path to output model directory.", + ) parser.add_argument( "--experiment", required=True, help="Name of experiment." ) - # Path arguments. parser.add_argument( "--train", required=True, @@ -255,11 +260,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: required=True, help="Path to input validation data TSV.", ) - parser.add_argument( - "--model_dir", - required=True, - help="Path to output model directory.", - ) parser.add_argument( "--train_from", help="Path to ckpt checkpoint to resume training from.", From f4eb78ff1b7d82a5074dfdeec01326a747db6809 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 17 Jul 2023 16:37:56 -0400 Subject: [PATCH 18/18] Edits README. --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 70ef9d0c..b4679886 100644 --- a/README.md +++ b/README.md @@ -56,12 +56,11 @@ import yoyodyne ### Training -Training is performed by the [`yoyodyne-train`](yoyodyne.train.py) script. One +Training is performed by the [`yoyodyne-train`](yoyodyne/train.py) script. One must specify the following required arguments: +- `--model_dir`: path for model metadata and checkpoints - `--experiment`: name of experiment (pick something unique) -- `--model_dir`: path for model metadata and checkpoints output during - training - `--train`: path to TSV file containing training data - `--val`: path to TSV file containing validation data @@ -71,7 +70,7 @@ more information. ### Prediction -Prediction is performed by the [`yoyodyne-predict`](yoyodyne.predict.py%60) +Prediction is performed by the [`yoyodyne-predict`](yoyodyne/predict.py) script. One must specify the following required arguments: - `--model_dir`: path for model metadata