Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix measure performance frequency + Add loss log #56

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Trainers: Add check name collision for losses
emanuele committed May 7, 2020
commit c672bd4e4864dc95314e97e577b51bec16c9cc28
2 changes: 2 additions & 0 deletions src/ashpy/trainers/classifier.py
Original file line number Diff line number Diff line change
@@ -134,6 +134,8 @@ def toy_dataset():
self._loss = loss
self._loss.reduction = tf.keras.losses.Reduction.NONE

super()._check_loss_name_collision([self._loss])

self._avg_loss = ClassifierLoss(name="ashpy/avg_loss")
if metrics:
metrics = (*metrics, self._avg_loss)
8 changes: 8 additions & 0 deletions src/ashpy/trainers/gan.py
Original file line number Diff line number Diff line change
@@ -192,6 +192,10 @@ def __init__(
self._discriminator_loss = discriminator_loss
self._discriminator_loss.reduction = tf.losses.Reduction.NONE

super()._check_loss_name_collision(
[self._generator_loss, self._discriminator_loss]
)

losses_metrics = (
DiscriminatorLoss(name="ashpy/d_loss", logdir=logdir),
GeneratorLoss(name="ashpy/g_loss", logdir=logdir),
@@ -586,6 +590,10 @@ def __init__(
self._encoder_loss = encoder_loss
self._encoder_loss.reduction = tf.losses.Reduction.NONE

super()._check_loss_name_collision(
[self._generator_loss, self._discriminator_loss, self._encoder_loss]
)

ckpt_dict = {
self.ckpt_id_encoder: self._encoder,
self.ckpt_id_optimizer_encoder: self._encoder_optimizer,
20 changes: 20 additions & 0 deletions src/ashpy/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -173,6 +173,26 @@ def _check_name_collision(objects: List, obj_type: str):
raise ValueError(f"{obj_type} should have unique names.")
buffer.append(obj.name)

@staticmethod
def _check_loss_name_collision(losses: List[ashpy.losses.Executor]):
"""Check that all losses have unique names."""
names = []

for loss in losses:
if loss.name in names:
raise ValueError(f"Losses should have unique names.")
else:
names.append(loss.name)

if isinstance(loss, ashpy.losses.SumExecutor):
loss: ashpy.losses.SumExecutor
sublosses_names = [subloss.name for subloss in loss.sublosses]
for subloss_name in sublosses_names:
if loss.name in names:
raise ValueError(f"Losses should have unique names.")
else:
names.append(subloss_name)

def _validate_metrics(self):
"""Check if every metric is an :py:class:`ashpy.metrics.Metric`."""
validate_objects(self._metrics, Metric)