Fix measure performance frequency + Add loss log #56

Open · wants to merge 23 commits into base: master (showing changes from 14 commits)
6 changes: 5 additions & 1 deletion examples/gans/pix2pix_facades.py
@@ -22,6 +22,7 @@
import tensorflow as tf

from ashpy import LogEvalMode
+ from ashpy.callbacks import LogImageGANCallback
from ashpy.losses.gan import (
AdversarialLossType,
Pix2PixLoss,
@@ -177,6 +178,8 @@ def main(
if not logdir.exists():
logdir.mkdir(parents=True)

+ callbacks = [LogImageGANCallback()]

trainer = AdversarialTrainer(
generator=generator,
discriminator=discriminator,
@@ -188,9 +191,10 @@ def main(
metrics=metrics,
logdir=logdir,
log_eval_mode=LogEvalMode.TEST,
+ callbacks=callbacks,
)

- train_dataset = tf.data.Dataset.list_files(PATH + "train/*.jpg")
+ train_dataset = tf.data.Dataset.list_files(str(PATH / "train/*.jpg"))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train)
train_dataset = train_dataset.batch(BATCH_SIZE)
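For reference, a minimal sketch of the corrected pipeline in isolation, assuming PATH is now a pathlib.Path (the dataset root and constants here are placeholders, not the example's real values):

import tensorflow as tf
from pathlib import Path

PATH = Path("facades")  # hypothetical dataset root
BUFFER_SIZE, BATCH_SIZE = 400, 1

# tf.data expects string patterns, hence the explicit str() conversion:
train_dataset = tf.data.Dataset.list_files(str(PATH / "train/*.jpg"))
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)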
7 changes: 4 additions & 3 deletions examples/gans/pix2pix_facades_multi_gpu.py
@@ -18,6 +18,7 @@
Input Pipeline taken from: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
"""
import os
+ from pathlib import Path

import tensorflow as tf

@@ -31,7 +32,7 @@
from ashpy.models.convolutional.unet import FUNet
from ashpy.trainers.gan import AdversarialTrainer

- from .pix2pix_facades import BATCH_SIZE, BUFFER_SIZE, IMG_WIDTH, PATH, load_image_train
+ from pix2pix_facades import BATCH_SIZE, BUFFER_SIZE, IMG_WIDTH, PATH, load_image_train


def main(
@@ -94,7 +95,7 @@ def main(
)

metrics = []
- logdir = f'{"log"}/{dataset_name}/run_multi'
+ logdir = Path(f'{"log"}/{dataset_name}/run_multi')

if not logdir.exists():
logdir.mkdir(parents=True)
@@ -116,7 +117,7 @@ def main(
log_eval_mode=LogEvalMode.TEST,
)

- train_dataset = tf.data.Dataset.list_files(PATH + "train/*.jpg")
+ train_dataset = tf.data.Dataset.list_files(str(PATH / "train/*.jpg"))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train)
train_dataset = train_dataset.batch(BATCH_SIZE)
35 changes: 24 additions & 11 deletions src/ashpy/callbacks/classifier.py
@@ -13,7 +13,7 @@
# limitations under the License.

"""Classifier callbacks."""

+ import tensorflow as tf
from ashpy.callbacks.counter_callback import CounterCallback
from ashpy.callbacks.events import Event
from ashpy.contexts import ClassifierContext
@@ -30,26 +30,29 @@ class LogClassifierCallback(CounterCallback):
def __init__(
self,
event: Event = Event.ON_EPOCH_END,
name="log_classifier_callback",
name: str = "log_classifier_callback",
event_freq: int = 1,
+ input_is_zero_centered: bool = True,
):
"""
Initialize the LogClassifierCallback.

Args:
- event: event to consider
- event_freq: frequency of logging
-
+ name (str): name of the callback
+ event (ashpy.events.Event): event to consider
+ event_freq (int): frequency of logging
+ input_is_zero_centered (bool): if True, the callback assumes float image
+     inputs are in [-1, 1] and rescales them to [0, 1] before logging; if
+     False, float inputs are assumed to already be in [0, 1]. For uint
+     inputs (in [0, 255]) this parameter is ignored.
"""
super(LogClassifierCallback, self).__init__(
- event=event,
- fn=LogClassifierCallback._log_fn,
- name=name,
- event_freq=event_freq,
+ event=event, fn=self._log_fn, name=name, event_freq=event_freq,
)
+ self._input_is_zero_centered = input_is_zero_centered

- @staticmethod
- def _log_fn(context: ClassifierContext) -> None:
+ def _log_fn(self, context: ClassifierContext) -> None:
"""
Log output of the image and label to Tensorboard.

@@ -60,5 +63,15 @@ def _log_fn(context: ClassifierContext) -> None:
input_tensor = context.current_batch[0]
out_label = context.current_batch[1]

+ rank = tf.rank(input_tensor)
+
+ # if it is an image, check if we need to scale and shift
+ if (
+     tf.equal(rank, 4)
+     and (input_tensor.dtype == tf.float32 or input_tensor.dtype == tf.float64)
+     and self._input_is_zero_centered
+ ):
+     input_tensor = (input_tensor + 1) / 2
+
log("input_x", input_tensor, context.global_step)
log("input_y", out_label, context.global_step)
8 changes: 5 additions & 3 deletions src/ashpy/losses/classifier.py
@@ -24,19 +24,20 @@
class ClassifierLoss(Executor):
r"""Classifier Loss Executor using the classifier model, instantiated with a fn."""

- def __init__(self, fn: tf.keras.losses.Loss) -> None:
+ def __init__(self, fn: tf.keras.losses.Loss, name: str = "ClassifierLoss") -> None:
r"""
Initialize :py:class:`ClassifierLoss`.

Args:
fn (:py:class:`tf.keras.losses.Loss`): Classification Loss function, should
take as input labels and prediction.
+ name (str): Name of the loss. It will be used for logging in TensorBoard.

Returns:
:py:obj:`None`

"""
- super().__init__(fn)
+ super().__init__(fn, name=name)

@Executor.reduce_loss
def call(
@@ -69,4 +70,5 @@ def call(
lambda: loss,
lambda: tf.expand_dims(tf.expand_dims(loss, axis=-1), axis=-1),
)
- return tf.reduce_mean(loss, axis=[1, 2])
+ loss = tf.reduce_mean(loss, axis=[1, 2])
+ return loss
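A quick illustration of the reduction above with dummy values (not ashpy code):

import tensorflow as tf

# Pixel-wise losses arrive as (batch, height, width); averaging over the
# two spatial axes leaves one loss value per example.
loss = tf.random.uniform((4, 32, 32))
per_example = tf.reduce_mean(loss, axis=[1, 2])
print(per_example.shape)  # (4,)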
47 changes: 38 additions & 9 deletions src/ashpy/losses/executor.py
@@ -29,12 +29,13 @@
class Executor:
"""Carry a function and the way of executing it. Given a context."""

- def __init__(self, fn: tf.keras.losses.Loss = None) -> None:
+ def __init__(self, fn: tf.keras.losses.Loss = None, name: str = "loss") -> None:
"""
Initialize the Executor.

Args:
fn (:py:class:`tf.keras.losses.Loss`): A Keras Loss to execute.
+ name (str): Name of the loss. It will be used for logging in TensorBoard.

Returns:
:py:obj:`None`
@@ -48,6 +49,8 @@ def __init__(self, fn: tf.keras.losses.Loss = None) -> None:
self._distribute_strategy = tf.distribute.get_strategy()
self._global_batch_size = -1
self._weight = lambda _: 1.0
+ self._name = name
+ self._loss_value = 0

@property
def weight(self) -> Callable[..., float]:
@@ -153,17 +156,29 @@ def __call__(self, context, **kwargs) -> tf.Tensor:
:py:obj:`tf.Tensor`: Output Tensor.

"""
- return self._weight(context.global_step) * self.call(context, **kwargs)
+ self._loss_value = self._weight(context.global_step) * self.call(
+     context, **kwargs
+ )
+ return self._loss_value

- def __add__(self, other) -> SumExecutor:
+ def log(self, step: tf.Variable) -> None:
+     """
+     Log the loss value on TensorBoard.
+
+     Args:
+         step (tf.Variable): current training step.
+     """
+     tf.summary.scalar(f"ashpy/losses/{self._name}", self._loss_value, step=step)
+
+ def __add__(self, other: Union[SumExecutor, Executor]) -> SumExecutor:
"""Concatenate Executors together into a SumExecutor."""
if isinstance(other, SumExecutor):
other_executors = other.executors
else:
other_executors = [other]

all_executors = [self] + other_executors
- return SumExecutor(all_executors)
+ return SumExecutor(all_executors, name=f"{self._name}+{other._name}")

def __mul__(self, other: Union[Callable[..., float], float, int, tf.Tensor]):
"""
@@ -185,7 +200,7 @@ def __mul__(self, other: Union[Callable[..., float], float, int, tf.Tensor]):
self._weight = lambda step: weight(step) * __other(step)
return self

- def __rmul__(self, other):
+ def __rmul__(self, other: Union[Callable[..., float], float, int, tf.Tensor]):
"""See `__mul__` method."""
return self * other

@@ -198,19 +213,20 @@ class SumExecutor(Executor):
then summed together.
"""

- def __init__(self, executors) -> None:
+ def __init__(self, executors: List[Executor], name: str = "LossSum") -> None:
"""
Initialize the SumExecutor.

Args:
executors (:py:obj:`list` of [:py:class:`ashpy.executors.Executor`]): Array of
    :py:obj:`ashpy.executors.Executor` to evaluate and sum together.
+ name (str): Name of the loss. It will be used for logging in TensorBoard.

Returns:
:py:obj:`None`

"""
- super().__init__()
+ super().__init__(name=name)
self._executors = executors
self._global_batch_size = 1

@@ -235,8 +251,21 @@ def call(self, *args, **kwargs) -> tf.Tensor:
:py:class:`tf.Tensor`: Output Tensor.

"""
- result = tf.add_n([executor(*args, **kwargs) for executor in self._executors])
- return result
+ self._loss_value = tf.add_n(
+     [executor(*args, **kwargs) for executor in self._executors]
+ )
+ return self._loss_value

+ def log(self, step: tf.Variable) -> None:
+     """
+     Log the total loss and all of its sub-losses on TensorBoard.
+
+     Args:
+         step (tf.Variable): current training step.
+     """
+     super().log(step)
+     for executor in self._executors:
+         executor.log(step)

def __add__(self, other: Union[SumExecutor, Executor]):
"""Concatenate Executors together into a SumExecutor."""
56 changes: 33 additions & 23 deletions src/ashpy/losses/gan.py
@@ -114,16 +114,21 @@ def get_discriminator_inputs(
class GeneratorAdversarialLoss(GANExecutor):
r"""Base class for the adversarial loss of the generator."""

- def __init__(self, loss_fn: tf.keras.losses.Loss = None) -> None:
+ def __init__(
+     self,
+     loss_fn: tf.keras.losses.Loss = None,
+     name: str = "GeneratorAdversarialLoss",
+ ) -> None:
"""
Initialize the Executor.
Args:
loss_fn (:py:class:`tf.keras.losses.Loss`): Keras Loss function to call
passing (tf.ones_like(d_fake_i), d_fake_i).
+ name (str): Name of the loss. It will be used for logging in TensorBoard.
"""
- super().__init__(loss_fn)
+ super().__init__(loss_fn, name=name)

@Executor.reduce_loss
def call(
@@ -187,8 +192,9 @@ class GeneratorBCE(GeneratorAdversarialLoss):

def __init__(self, from_logits: bool = True) -> None:
"""Initialize the BCE Loss for the Generator."""
self.name = "GeneratorBCE"
super().__init__(tf.keras.losses.BinaryCrossentropy(from_logits=from_logits))
super().__init__(
tf.keras.losses.BinaryCrossentropy(from_logits=from_logits), "GeneratorBCE"
)


class GeneratorLSGAN(GeneratorAdversarialLoss):
@@ -208,8 +214,7 @@ class GeneratorLSGAN(GeneratorAdversarialLoss):

def __init__(self) -> None:
"""Initialize the Least Square Loss for the Generator."""
- super().__init__(tf.keras.losses.MeanSquaredError())
- self.name = "GeneratorLSGAN"
+ super().__init__(tf.keras.losses.MeanSquaredError(), name="GeneratorLSGAN")


class GeneratorL1(GANExecutor):
@@ -225,7 +230,7 @@ class GeneratorL1(GANExecutor):

def __init__(self) -> None:
"""Initialize the Executor."""
- super().__init__(L1())
+ super().__init__(L1(), name="GeneratorL1")

@Executor.reduce_loss
def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
@@ -256,8 +261,7 @@ class GeneratorHingeLoss(GeneratorAdversarialLoss):

def __init__(self) -> None:
"""Initialize the Least Square Loss for the Generator."""
super().__init__(GHingeLoss())
self.name = "GeneratorHingeLoss"
super().__init__(GHingeLoss(), "GeneratorHingeLoss")


class FeatureMatchingLoss(GANExecutor):
@@ -290,7 +294,7 @@ class FeatureMatchingLoss(GANExecutor):

def __init__(self) -> None:
"""Initialize the Executor."""
- super().__init__(L1())
+ super().__init__(L1(), "FeatureMatchingLoss")

@Executor.reduce_loss
def call(
@@ -355,8 +359,7 @@ class CategoricalCrossEntropy(Executor):

def __init__(self) -> None:
"""Initialize the Categorical Cross Entropy Executor."""
self.name = "CrossEntropy"
super().__init__(tf.keras.losses.CategoricalCrossentropy())
super().__init__(tf.keras.losses.CategoricalCrossentropy(), "CrossEntropy")

@Executor.reduce_loss
def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
@@ -431,7 +434,7 @@ def __init__(
if use_feature_matching_loss:
executors.append(FeatureMatchingLoss() * feature_matching_weight)

- super().__init__(executors)
+ super().__init__(executors, name="Pix2PixLoss")


class Pix2PixLossSemantic(SumExecutor):
@@ -479,7 +482,7 @@ def __init__(

if use_feature_matching_loss:
executors.append(FeatureMatchingLoss() * feature_matching_weight)
- super().__init__(executors)
+ super().__init__(executors, name="Pix2PixLossSemantic")


# TODO: Check if this supports condition
@@ -488,7 +491,10 @@ class EncoderBCE(Executor):

def __init__(self, from_logits: bool = True) -> None:
"""Initialize the Executor."""
- super().__init__(tf.keras.losses.BinaryCrossentropy(from_logits=from_logits))
+ super().__init__(
+     tf.keras.losses.BinaryCrossentropy(from_logits=from_logits),
+     name="EncoderBCE",
+ )

@Executor.reduce_loss
def call(
@@ -515,16 +521,21 @@ def call(
class DiscriminatorAdversarialLoss(GANExecutor):
r"""Base class for the adversarial loss of the discriminator."""

- def __init__(self, loss_fn: tf.keras.losses.Loss = None) -> None:
+ def __init__(
+     self,
+     loss_fn: tf.keras.losses.Loss = None,
+     name: str = "DiscriminatorAdversarialLoss",
+ ) -> None:
r"""
Initialize the Executor.
Args:
loss_fn (:py:class:`tf.keras.losses.Loss`): Loss function call passing
-     (d_real, d_fake).
+         (d_real, d_fake).
+     name (str): Name of the loss. It will be used for logging in TensorBoard.
"""
- super().__init__(loss_fn)
+ super().__init__(loss_fn, name=name)

@Executor.reduce_loss
def call(
@@ -591,7 +602,8 @@ class DiscriminatorMinMax(DiscriminatorAdversarialLoss):
def __init__(self, from_logits=True, label_smoothing=0.0):
"""Initialize Loss."""
super().__init__(
- DMinMax(from_logits=from_logits, label_smoothing=label_smoothing)
+ DMinMax(from_logits=from_logits, label_smoothing=label_smoothing),
+ name="DiscriminatorMinMax",
)


@@ -625,8 +637,7 @@ class DiscriminatorLSGAN(DiscriminatorAdversarialLoss):

def __init__(self) -> None:
"""Initialize loss."""
- super().__init__(DLeastSquare())
- self.name = "DiscriminatorLSGAN"
+ super().__init__(DLeastSquare(), name="DiscriminatorLSGAN")


class DiscriminatorHingeLoss(DiscriminatorAdversarialLoss):
@@ -640,8 +651,7 @@ class DiscriminatorHingeLoss(DiscriminatorAdversarialLoss):

def __init__(self) -> None:
"""Initialize the Least Square Loss for the Generator."""
super().__init__(DHingeLoss())
self.name = "DiscriminatorHingeLoss"
super().__init__(DHingeLoss(), name="DiscriminatorHingeLoss")


###
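With these defaults every concrete loss lands in TensorBoard under a stable tag; a quick check of the wiring (the private _name attribute is introduced by this PR):

from ashpy.losses.gan import DiscriminatorLSGAN, GeneratorLSGAN

g_loss = GeneratorLSGAN()
d_loss = DiscriminatorLSGAN()
# Once the trainer calls .log(step), scalars appear as
# ashpy/losses/GeneratorLSGAN and ashpy/losses/DiscriminatorLSGAN.
print(g_loss._name, d_loss._name)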
8 changes: 6 additions & 2 deletions src/ashpy/metrics/classifier.py
@@ -17,7 +17,7 @@
from __future__ import annotations

from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

import tensorflow as tf # pylint: disable=import-error
from ashpy.metrics.metric import Metric
@@ -90,6 +90,7 @@ class ClassifierMetric(Metric):
def __init__(
self,
metric: tf.keras.metrics.Metric,
+ name: Optional[str] = None,
model_selection_operator: Callable = None,
logdir: Union[Path, str] = Path().cwd() / "log",
processing_predictions=None,
@@ -100,6 +101,7 @@ def __init__(
Args:
metric (:py:class:`tf.keras.metrics.Metric`): The Keras Metric to use with
the classifier (e.g.: Accuracy()).
+ name (str): The name of the metric. If None, the metric.name property is used.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
@@ -116,8 +118,10 @@ def __init__(
keyword-arguments. Defaults to {"fn": tf.argmax, "kwargs": {"axis": -1}}.
"""
+ if name is None:
+     name = metric.name
super().__init__(
-     name=metric.name,
+     name=name,
metric=metric,
model_selection_operator=model_selection_operator,
logdir=logdir,
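A minimal sketch of why the override matters: two Keras metrics of the same class would otherwise collide on the identical Keras-assigned name (the metric names here are made up):

import tensorflow as tf
from ashpy.metrics import ClassifierMetric

# Distinct names keep the two metrics apart in TensorBoard and in
# model-selection logs:
acc_a = ClassifierMetric(tf.keras.metrics.Accuracy(), name="accuracy_model_a")
acc_b = ClassifierMetric(tf.keras.metrics.Accuracy(), name="accuracy_model_b")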
10 changes: 8 additions & 2 deletions src/ashpy/trainers/classifier.py
@@ -187,6 +187,8 @@ def train_step(self, features, labels):

gradients = tape.gradient(loss, self._model.trainable_variables)
self._optimizer.apply_gradients(zip(gradients, self._model.trainable_variables))

+ self._loss.log(self._global_step)
return loss

@tf.function
@@ -220,8 +222,6 @@ def call(
performance.
"""
- if self._deferred_restoration:
-     self._build_and_restore_models(dataset=training_set)

# set the context properties
self._context.training_set = training_set
@@ -242,6 +242,9 @@ def call(
)
)

+ if self._deferred_restoration:
+     self._build_and_restore_models(dataset=training_set)

with self._train_summary_writer.as_default():

# notify on train start
@@ -286,6 +289,9 @@ def call(
# notify on epoch end
self._on_epoch_end()

+ self.context.dataset = training_set
+ self._measure_performance()

with self._eval_summary_writer.as_default():
self._context.dataset = validation_set
self._measure_performance()
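This is the titular frequency fix: performance is now measured once per epoch on the training set (under the train summary writer) and once on the validation set (under the eval writer), and deferred restoration runs only after the context has a dataset to build the models from. A schematic of the resulting loop in plain Python (placeholder functions, not the actual trainer API):

def train_step(batch): ...             # placeholder: one optimization step
def measure_performance(dataset): ...  # placeholder: metrics update + log

def fit(epochs, training_set, validation_set):
    for _ in range(epochs):
        for batch in training_set:
            train_step(batch)                  # also logs the loss values
        measure_performance(training_set)      # train summary writer
        measure_performance(validation_set)    # eval summary writer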
42 changes: 32 additions & 10 deletions src/ashpy/trainers/gan.py
@@ -232,10 +232,24 @@ def __init__(

def _build_and_restore_models(self, dataset: tf.data.Dataset):
restorer = ashpy.restorers.AdversarialRestorer(self._logdir)
- (x, _), z = next(iter(dataset.take(1)))
- # Invoke model on sample input
- self._generator(z)
- self._discriminator(x)
+ (real_x, real_y), g_input = next(iter(dataset.take(1)))
+ # prepare g inputs
+ if len(self._generator.inputs) == 2:
+     g_inputs = [g_input, real_x]
+ else:
+     g_inputs = g_input
+ # call G on its inputs
+ self._generator(g_inputs)
+
+ # prepare d inputs
+ if len(self._discriminator.inputs) == 2:
+     d_inputs = [real_x, real_y]
+ else:
+     d_inputs = real_x
+ # call D on its inputs
+ self._discriminator(d_inputs)
+
+ # restore models
restorer.restore_generator(self._generator)
restorer.restore_discriminator(self._discriminator)
self._deferred_restoration = False
@@ -289,6 +303,10 @@ def train_step(self, real_xy, g_inputs):
zip(g_gradients, self._generator.trainable_variables)
)

+ # log losses
+ self._discriminator_loss.log(self._global_step)
+ self._generator_loss.log(self._global_step)

return d_loss, g_loss, fake

@tf.function
@@ -328,9 +346,6 @@ def call(
performance.
"""
- if self._deferred_restoration:
-     self._build_and_restore_models(dataset=dataset)

current_epoch = self._current_epoch()

self._update_global_batch_size(
@@ -344,6 +359,9 @@ def call(

self._context.generator_inputs = samples[1]

+ if self._deferred_restoration:
+     self._build_and_restore_models(dataset=dataset)

with self._train_summary_writer.as_default():

# notify on train start
@@ -652,6 +670,10 @@ def train_step(self, real_xy, g_inputs):
zip(e_gradients, self._encoder.trainable_variables)
)

+ self._discriminator_loss.log(self._global_step)
+ self._generator_loss.log(self._global_step)
+ self._encoder_loss.log(self._global_step)

return d_loss, g_loss, e_loss, fake, generator_of_encoder

@tf.function
@@ -693,9 +715,6 @@ def call(
performance.
"""
- if self._deferred_restoration:
-     self._build_and_restore_models(dataset=dataset)

current_epoch = self._current_epoch()

self._update_global_batch_size(
@@ -712,6 +731,9 @@ def call(
self._context.generator_inputs = samples[1]
self._context.encoder_inputs = samples[0][0]

+ if self._deferred_restoration:
+     self._build_and_restore_models(dataset=dataset)

with self._train_summary_writer.as_default():

# notify on train start event
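The restoration fix hinges on feeding each network the right arity. A standalone sketch of the same dispatch with a toy conditional generator (model and shapes are made up):

import tensorflow as tf

def prepare_inputs(model: tf.keras.Model, primary, condition):
    # Conditional models declare two Keras inputs; only then is the
    # condition passed alongside the primary input, mirroring the
    # trainer logic above.
    if len(model.inputs) == 2:
        return [primary, condition]
    return primary

z_in = tf.keras.Input(shape=(8,))
x_in = tf.keras.Input(shape=(8,))
merged = tf.keras.layers.Concatenate()([z_in, x_in])
cond_generator = tf.keras.Model([z_in, x_in], tf.keras.layers.Dense(4)(merged))

noise = tf.random.normal((2, 8))
condition = tf.random.normal((2, 8))
fake = cond_generator(prepare_inputs(cond_generator, noise, condition))
print(fake.shape)  # (2, 4)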
1 change: 0 additions & 1 deletion tests/test_restorers.py
@@ -23,7 +23,6 @@
have the same weights.
"""
from pathlib import Path
- from typing import Union

import pytest
import tensorflow as tf
4 changes: 1 addition & 3 deletions tests/test_trainers.py
@@ -18,11 +18,9 @@
from typing import List

import ashpy
- import pytest
import tensorflow as tf
from ashpy.trainers import AdversarialTrainer, ClassifierTrainer

- from tests.test_restorers import ModelNotConstructedError, _check_models_weights
+ from tests.test_restorers import _check_models_weights
from tests.utils.fake_training_loop import (
FakeAdversarialTraining,
FakeClassifierTraining,
1 change: 0 additions & 1 deletion tests/utils/fake_training_loop.py
@@ -20,7 +20,6 @@
import ashpy
import tensorflow as tf
from ashpy.losses import DiscriminatorMinMax, GeneratorBCE
- from ashpy.models.gans import ConvDiscriminator, ConvGenerator
from ashpy.trainers import AdversarialTrainer, ClassifierTrainer, Trainer

from tests.utils.fake_datasets import (