Skip to content

Commit

Permalink
Merge pull request #110 from CUNY-CL/data0
Browse files Browse the repository at this point in the history
Migrates to data modules
  • Loading branch information
kylebgorman authored Jul 17, 2023
2 parents 28ff0e5 + f4eb78f commit 8a4b925
Show file tree
Hide file tree
Showing 24 changed files with 881 additions and 969 deletions.
38 changes: 31 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,32 @@ import yoyodyne

## Usage

See [`yoyodyne-predict --help`](yoyodyne/predict.py) and
[`yoyodyne-train --help`](yoyodyne/train.py).
### Training

Training is performed by the [`yoyodyne-train`](yoyodyne/train.py) script. One
must specify the following required arguments:

- `--model_dir`: path for model metadata and checkpoints
- `--experiment`: name of experiment (pick something unique)
- `--train`: path to TSV file containing training data
- `--val`: path to TSV file containing validation data

The user can also specify various optional training and architectural
arguments. See below or run [`yoyodyne-train --help`](yoyodyne/train.py) for
more information.

### Prediction

Prediction is performed by the [`yoyodyne-predict`](yoyodyne/predict.py)
script. One must specify the following required arguments:

- `--model_dir`: path for model metadata
- `--experiment`: name of experiment
- `--checkpoint`: path to checkpoint
- `--predict`: path to TSV file containing data to be predicted
- `--output`: path for predictions

Run [`yoyodyne-predict --help`](yoyodyne/predict.py) for more information.

## Data format

Expand All @@ -73,11 +97,12 @@ the third contains semi-colon delimited feature strings:

this format is specified by `--features-col 3`.

Alternatively, for the SIGMORPHON 2016 shared task data format:
Alternatively, for the [SIGMORPHON 2016 shared
task](https://sigmorphon.github.io/sharedtasks/2016/) data:

source feat1,feat2,... target

this format is specified by `--features-col 2 --features-sep , --target-col 3`.
this format is specified by `--features_col 2 --features_sep , --target_col 3`.

In order to ensure that targets are ignored during prediction, one can specify
`--target_col 0`.
Expand All @@ -101,7 +126,7 @@ checkpoint, the **full path to the checkpoint** should be specified with for the
the provided model checkpoint.

During training, we save the best `--save_top_k` checkpoints (by default, 1)
ranked according to accuracy on the `--dev` set. For example, `--save_top_k 5`
ranked according to accuracy on the `--val` set. For example, `--save_top_k 5`
will save the top 5 most accurate models.

## Reserved symbols
Expand Down Expand Up @@ -153,8 +178,7 @@ source encoder using the `--source_encoder` flag:
- `"transformer"`: This is a transformer encoder.

When using features, the user can also specify a non-default features encoder
using the `--features_encoder` flag (`"linear"`, `"lstm"`,
`"transformer"`).
using the `--features_encoder` flag (`"linear"`, `"lstm"`, `"transformer"`).

For all models, the user may also wish to specify:

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ exclude = ["examples*"]

[project]
name = "yoyodyne"
version = "0.2.3"
version = "0.2.4"
description = "Small-vocabulary neural sequence-to-sequence models"
readme = "README.md"
requires-python = ">= 3.9"
Expand Down
47 changes: 0 additions & 47 deletions tests/collator_test.py

This file was deleted.

14 changes: 0 additions & 14 deletions tests/dataset_test.py

This file was deleted.

89 changes: 89 additions & 0 deletions yoyodyne/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Data classes."""

import argparse

from .. import defaults
from .datamodules import DataModule # noqa: F401
from .batches import PaddedBatch, PaddedTensor # noqa: F401
from .indexes import Index # noqa: F401


def add_argparse_args(parser: argparse.ArgumentParser) -> None:
"""Adds data options to the argument parser.
Args:
parser (argparse.ArgumentParser).
"""
parser.add_argument(
"--source_col",
type=int,
default=defaults.SOURCE_COL,
help="1-based index for source column. Default: %(default)s.",
)
parser.add_argument(
"--target_col",
type=int,
default=defaults.TARGET_COL,
help="1-based index for target column. Default: %(default)s.",
)
parser.add_argument(
"--features_col",
type=int,
default=defaults.FEATURES_COL,
help="1-based index for features column; "
"0 indicates the model will not use features. "
"Default: %(default)s.",
)
parser.add_argument(
"--source_sep",
type=str,
default=defaults.SOURCE_SEP,
help="String used to split source string into symbols; "
"an empty string indicates that each Unicode codepoint "
"is its own symbol. Default: %(default)r.",
)
parser.add_argument(
"--target_sep",
type=str,
default=defaults.TARGET_SEP,
help="String used to split target string into symbols; "
"an empty string indicates that each Unicode codepoint "
"is its own symbol. Default: %(default)r.",
)
parser.add_argument(
"--features_sep",
type=str,
default=defaults.FEATURES_SEP,
help="String used to split features string into symbols; "
"an empty string indicates that each Unicode codepoint "
"is its own symbol. Default: %(default)r.",
)
parser.add_argument(
"--tied_vocabulary",
action="store_true",
default=defaults.TIED_VOCABULARY,
help="Share source and target embeddings. Default: %(default)s.",
)
parser.add_argument(
"--no_tied_vocabulary",
action="store_false",
dest="tied_vocabulary",
default=True,
)
parser.add_argument(
"--batch_size",
type=int,
default=defaults.BATCH_SIZE,
help="Batch size. Default: %(default)s.",
)
parser.add_argument(
"--max_source_length",
type=int,
default=defaults.MAX_SOURCE_LENGTH,
help="Maximum source string length. Default: %(default)s.",
)
parser.add_argument(
"--max_target_length",
type=int,
default=defaults.MAX_TARGET_LENGTH,
help="Maximum target string length. Default: %(default)s.",
)
File renamed without changes.
41 changes: 7 additions & 34 deletions yoyodyne/collators.py → yoyodyne/data/collators.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,30 @@
"""Collators and related utilities."""

import argparse
import dataclasses
from typing import List

import torch

from . import batches, datasets, defaults, util
from .. import defaults, util
from . import batches, datasets


class LengthError(Exception):
pass


@dataclasses.dataclass
class Collator:
"""Pads data."""

pad_idx: int
features_offset: int
has_features: bool
has_target: bool
max_source_length: int
max_target_length: int
separate_features: bool

def __init__(
self,
dataset: datasets.BaseDataset,
arch: str,
max_source_length: int = defaults.MAX_SOURCE_LENGTH,
max_target_length: int = defaults.MAX_TARGET_LENGTH,
):
"""Initializes the collator.
Args:
dataset (dataset.BaseDataset).
arch (str).
max_source_length (int).
max_target_length (int).
"""
self.index = dataset.index
self.pad_idx = self.index.pad_idx
self.config = dataset.config
self.has_features = self.config.has_features
self.has_target = self.config.has_target
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.features_offset = (
dataset.index.source_vocab_size if self.has_features else 0
)
self.separate_features = dataset.config.has_features and arch in [
"pointer_generator_lstm",
"transducer",
]
features_offset: int
max_source_length: int = defaults.MAX_SOURCE_LENGTH
max_target_length: int = defaults.MAX_TARGET_LENGTH

def _source_length_error(self, padded_length: int) -> None:
"""Callback function to raise the error when the padded length of the
Expand Down
Loading

0 comments on commit 8a4b925

Please sign in to comment.