Merge pull request #266 from CUNY-CL/mapper
Moves to mapper interface
kylebgorman authored Dec 2, 2024
2 parents 0a91f56 + 8bc0127 commit f3550a1
Showing 12 changed files with 276 additions and 207 deletions.
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -11,7 +11,7 @@ exclude = ["examples*"]

[project]
name = "yoyodyne"
version = "0.2.16"
version = "0.2.17"
description = "Small-vocabulary neural sequence-to-sequence models"
readme = "README.md"
requires-python = ">= 3.9"
yoyodyne/data/__init__.py: 2 changes (2 additions, 0 deletions)
@@ -7,6 +7,8 @@
from .datamodules import DataModule # noqa: F401
from .datasets import Dataset # noqa: F401
from .indexes import Index # noqa: F401
from .mappers import Mapper # noqa: F401
from .tsv import TsvParser # noqa: F401


def add_argparse_args(parser: argparse.ArgumentParser) -> None:
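With this change, Mapper joins the classes re-exported from yoyodyne.data. Below is a minimal, hypothetical sketch of the new surface; the toy vocabularies are placeholders, and only constructor arguments that appear elsewhere in this diff are used.

    from yoyodyne.data import Index, Mapper

    # Builds a toy index; the keyword arguments mirror the updated call in
    # DataModule._make_index (datamodules.py below).
    index = Index(
        source_vocabulary={"a", "b", "c"},
        features_vocabulary=None,
        target_vocabulary={"a", "b", "c"},
        tie_embeddings=True,
    )

    # Datasets now receive a Mapper wrapping the index rather than the raw
    # index itself (see DataModule._dataset in datamodules.py below).
    mapper = Mapper(index)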
yoyodyne/data/collators.py: 1 change (1 addition, 0 deletions)
@@ -2,6 +2,7 @@

import argparse
import dataclasses

from typing import List

import torch
yoyodyne/data/datamodules.py: 85 changes (61 additions, 24 deletions)
@@ -6,13 +6,42 @@
from torch.utils import data

from .. import defaults, util
from . import collators, datasets, indexes, tsv
from . import collators, datasets, indexes, mappers, tsv


class DataModule(lightning.LightningDataModule):
"""Parses, indexes, collates and loads data.
The batch size tuner is permitted to mutate the `batch_size` argument.
"""Data module.
This is responsible for indexing the data, collating/padding, and
generating datasets.
Args:
model_dir: Path for checkpoints, indexes, and logs.
train: Path for training data TSV.
val: Path for validation data TSV.
predict: Path for prediction data TSV.
test: Path for test data TSV.
source_col: 1-indexed column in TSV containing source strings.
features_col: 1-indexed column in TSV containing features strings.
target_col: 1-indexed column in TSV containing target strings.
source_sep: String used to split source string into symbols; an empty
string indicates that each Unicode codepoint is its own symbol.
features_sep: String used to split features string into symbols; an
empty string indicates that each Unicode codepoint is its own
symbol.
target_sep: String used to split target string into symbols; an empty
string indicates that each Unicode codepoint is its own symbol.
separate_features: Whether or not a separate encoder should be used
for features.
tie_embeddings: Whether or not source and target embeddings are tied.
If not, then source symbols are wrapped in {...}.
batch_size: Desired batch size.
max_source_length: The maximum length of a source string; this includes
concatenated feature strings if not using separate features. An
error will be raised if any source exceeds this limit.
max_target_length: The maximum length of a target string. A warning
will be raised and the target strings will be truncated if any
target exceeds this limit.
"""

train: Optional[str]
@@ -37,18 +37,16 @@ def __init__(
source_col: int = defaults.SOURCE_COL,
features_col: int = defaults.FEATURES_COL,
target_col: int = defaults.TARGET_COL,
# String parsing arguments.
source_sep: str = defaults.SOURCE_SEP,
features_sep: str = defaults.FEATURES_SEP,
target_sep: str = defaults.TARGET_SEP,
# Collator options.
batch_size: int = defaults.BATCH_SIZE,
# Modeling options.
separate_features: bool = False,
tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
# Other.
batch_size: int = defaults.BATCH_SIZE,
max_source_length: int = defaults.MAX_SOURCE_LENGTH,
max_target_length: int = defaults.MAX_TARGET_LENGTH,
tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
# Indexing.
index: Optional[indexes.Index] = None,
):
super().__init__()
self.train = train
@@ -83,7 +83,7 @@ def __init__(
def _make_index(
self, model_dir: str, tie_embeddings: bool
) -> indexes.Index:
# Computes index.
"""Creates the index from a training set."""
source_vocabulary: Set[str] = set()
features_vocabulary: Set[str] = set()
target_vocabulary: Set[str] = set()
@@ -107,21 +107,22 @@ def _make_index(
for source in self.parser.samples(self.train):
source_vocabulary.update(source)
index = indexes.Index(
source_vocabulary=sorted(source_vocabulary),
source_vocabulary=source_vocabulary,
features_vocabulary=(
sorted(features_vocabulary) if features_vocabulary else None
),
target_vocabulary=(
sorted(target_vocabulary) if target_vocabulary else None
features_vocabulary if features_vocabulary else None
),
target_vocabulary=target_vocabulary if target_vocabulary else None,
tie_embeddings=tie_embeddings,
)
# Writes it to the model directory.
index.write(model_dir)
return index

# Logging.

@staticmethod
def pprint(vocabulary: Iterable) -> str:
"""Prints the vocabulary for debugging adn logging purposes."""
"""Prints the vocabulary for debugging dnd logging purposes."""
return ", ".join(f"{symbol!r}" for symbol in vocabulary)

def log_vocabularies(self) -> None:
@@ -140,6 +140,8 @@ def log_vocabularies(self) -> None:
f"{self.pprint(self.index.target_vocabulary)}"
)

# Properties.

@property
def has_features(self) -> bool:
return self.parser.has_features
@@ -148,13 +178,6 @@ def has_features(self) -> bool:
def has_target(self) -> bool:
return self.parser.has_target

def _dataset(self, path: str) -> datasets.Dataset:
return datasets.Dataset(
list(self.parser.samples(path)),
self.index,
self.parser,
)

# Required API.

def train_dataloader(self) -> data.DataLoader:
@@ -165,6 +188,7 @@ def train_dataloader(self) -> data.DataLoader:
batch_size=self.batch_size,
shuffle=True,
num_workers=1,
persistent_workers=True,
)

def val_dataloader(self) -> data.DataLoader:
@@ -173,7 +197,9 @@ def val_dataloader(self) -> data.DataLoader:
self._dataset(self.val),
collate_fn=self.collator,
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
persistent_workers=True,
)

def predict_dataloader(self) -> data.DataLoader:
@@ -182,7 +208,9 @@ def predict_dataloader(self) -> data.DataLoader:
self._dataset(self.predict),
collate_fn=self.collator,
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
persistent_workers=True,
)

def test_dataloader(self) -> data.DataLoader:
@@ -191,5 +219,14 @@ def test_dataloader(self) -> data.DataLoader:
self._dataset(self.test),
collate_fn=self.collator,
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
persistent_workers=True,
)

def _dataset(self, path: str) -> datasets.Dataset:
return datasets.Dataset(
list(self.parser.samples(path)),
mappers.Mapper(self.index),
self.parser,
)
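For orientation, a hypothetical end-to-end construction of the data module follows, using only arguments documented in the new docstring; the TSV paths, column numbers, and batch size are placeholders rather than values from this commit, and the files must exist for the sketch to run.

    from yoyodyne.data import DataModule

    datamodule = DataModule(
        model_dir="/tmp/yoyodyne",  # checkpoints, indexes, and logs
        train="train.tsv",          # placeholder training TSV
        val="val.tsv",              # placeholder validation TSV
        source_col=1,
        target_col=2,
        batch_size=32,
    )

    # Logs the vocabularies induced from the training data, then builds the
    # shuffled, persistent-worker training DataLoader defined above.
    datamodule.log_vocabularies()
    train_loader = datamodule.train_dataloader()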