Skip to content

Commit

Permalink
Merge pull request #106 from CUNY-CL/name
Browse files Browse the repository at this point in the history
Improves model logging
  • Loading branch information
kylebgorman authored Jul 7, 2023
2 parents 263f2ce + c2aedf2 commit ed277fc
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 12 deletions.
5 changes: 1 addition & 4 deletions yoyodyne/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import argparse

from .. import util
from .base import BaseEncoderDecoder
from .lstm import AttentiveLSTMEncoderDecoder, LSTMEncoderDecoder
from .pointer_generator import PointerGeneratorLSTMEncoderDecoder
Expand Down Expand Up @@ -31,9 +30,7 @@ def get_model_cls(arch: str) -> BaseEncoderDecoder:
"transformer": TransformerEncoderDecoder,
}
try:
model_cls = model_fac[arch]
util.log_info(f"Model: {model_cls.__name__}")
return model_cls
return model_fac[arch]
except KeyError:
raise NotImplementedError(f"Architecture {arch} not found")

Expand Down
10 changes: 9 additions & 1 deletion yoyodyne/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torch import nn, optim

from .. import batches, defaults, evaluators, schedulers
from .. import batches, defaults, evaluators, schedulers, util
from . import modules


Expand Down Expand Up @@ -150,6 +150,14 @@ def __init__(
self.save_hyperparameters(
ignore=["source_encoder", "decoder", "features_encoder"]
)
# Logs the module names.
util.log_info(f"Model: {self.name}")
if self.features_encoder is not None:
util.log_info(f"Source encoder: {self.source_encoder.name}")
util.log_info(f"Features encoder: {self.features_encoder.name}")
else:
util.log_info(f"Encoder: {self.source_encoder.name}")
util.log_info(f"Decoder: {self.decoder.name}")

@staticmethod
def _xavier_embedding_initialization(
Expand Down
8 changes: 8 additions & 0 deletions yoyodyne/models/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def forward(
predictions = predictions.transpose(0, 1)
return predictions

@property
def name(self) -> str:
    """Human-readable architecture name, used for logging."""
    label = "LSTM"
    return label

@staticmethod
def add_argparse_args(parser: argparse.ArgumentParser) -> None:
"""Adds LSTM configuration options to the argument parser.
Expand Down Expand Up @@ -346,3 +350,7 @@ def get_decoder(self):
hidden_size=self.hidden_size,
attention_input_size=self.source_encoder.output_size,
)

@property
def name(self) -> str:
    """Human-readable architecture name, used for logging."""
    label = "attentive LSTM"
    return label
9 changes: 2 additions & 7 deletions yoyodyne/models/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse

from ... import util
from .base import BaseModule
from .linear import LinearEncoder
from .lstm import LSTMAttentiveDecoder, LSTMDecoder, LSTMEncoder # noqa: F401
Expand Down Expand Up @@ -49,18 +48,14 @@ def get_encoder_cls(
}
if encoder_arch is None:
try:
model_cls = model_to_encoder_fac[model_arch]
util.log_info(f"Model: {model_cls.__name__}")
return model_cls
return model_to_encoder_fac[model_arch]
except KeyError:
raise NotImplementedError(
f"Encoder compatible with {model_arch} not found"
)
else:
try:
model_cls = encoder_fac[encoder_arch]
util.log_info(f"Model: {model_cls.__name__}")
return model_cls
return encoder_fac[encoder_arch]
except KeyError:
raise NotImplementedError(
f"Encoder architecture {encoder_arch} not found"
Expand Down
4 changes: 4 additions & 0 deletions yoyodyne/models/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def forward(
"""
return base.ModuleOutput(self.embed(source.padded))

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "linear"
    return label

@property
def output_size(self) -> int:
    """Dimensionality of this module's output vectors.

    A linear embedding passes the embedding dimension straight through.
    """
    size = self.embedding_size
    return size
12 changes: 12 additions & 0 deletions yoyodyne/models/modules/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def get_module(self) -> nn.LSTM:
def output_size(self) -> int:
    """Dimensionality of this module's output vectors.

    A bidirectional LSTM doubles the effective output width, so the
    hidden size is scaled by the number of directions.
    """
    return self.num_directions * self.hidden_size

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "LSTM"
    return label


class LSTMDecoder(LSTMModule):
def __init__(self, *args, decoder_input_size, **kwargs):
Expand Down Expand Up @@ -165,6 +169,10 @@ def get_module(self) -> nn.LSTM:
def output_size(self) -> int:
    """Dimensionality of this module's output vectors.

    The decoder emits one logit per target symbol, so the output size
    equals the embedding-table size.
    """
    vocabulary_size = self.num_embeddings
    return vocabulary_size

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "LSTM"
    return label


class LSTMAttentiveDecoder(LSTMDecoder):
attention_input_size: int
Expand Down Expand Up @@ -211,3 +219,7 @@ def forward(
)
output = self.dropout_layer(output)
return base.ModuleOutput(output, hiddens=hiddens)

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "attentive LSTM"
    return label
12 changes: 12 additions & 0 deletions yoyodyne/models/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def get_module(self) -> nn.TransformerEncoder:
def output_size(self) -> int:
    """Dimensionality of this module's output vectors.

    The transformer encoder preserves the embedding dimension.
    """
    size = self.embedding_size
    return size

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "transformer"
    return label


class FeatureInvariantTransformerEncoder(TransformerEncoder):
"""Encoder for Transformer with feature invariance.
Expand Down Expand Up @@ -230,6 +234,10 @@ def embed(self, symbols: torch.Tensor) -> torch.Tensor:
)
return out

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "feature-invariant transformer"
    return label


class TransformerDecoder(TransformerModule):
"""Decoder for Transformer."""
Expand Down Expand Up @@ -311,3 +319,7 @@ def generate_square_subsequent_mask(length: int) -> torch.Tensor:
@property
def output_size(self) -> int:
    """Dimensionality of this module's output vectors.

    The decoder emits one logit per target symbol, so the output size
    equals the embedding-table size.
    """
    vocabulary_size = self.num_embeddings
    return vocabulary_size

@property
def name(self) -> str:
    """Human-readable module name, used for logging."""
    label = "transformer"
    return label
4 changes: 4 additions & 0 deletions yoyodyne/models/pointer_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,7 @@ def _reshape_hiddens(
H = H.view(layers, num_directions, H.size(1), H.size(2)).sum(axis=1)
C = C.view(layers, num_directions, C.size(1), C.size(2)).sum(axis=1)
return (H, C)

@property
def name(self) -> str:
    """Human-readable architecture name, used for logging."""
    label = "pointer-generator"
    return label
4 changes: 4 additions & 0 deletions yoyodyne/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,5 +575,9 @@ def sample(log_probs: torch.Tensor) -> torch.Tensor:
break
return action

@property
def name(self) -> str:
    """Human-readable architecture name, used for logging."""
    label = "transducer"
    return label


# TODO: Implement beam decoding.
4 changes: 4 additions & 0 deletions yoyodyne/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def forward(
)
return output

@property
def name(self) -> str:
    """Human-readable architecture name, used for logging."""
    label = "transformer"
    return label

@staticmethod
def add_argparse_args(parser: argparse.ArgumentParser) -> None:
"""Adds transformer configuration options to the argument parser.
Expand Down

0 comments on commit ed277fc

Please sign in to comment.