Migrate to Keras 3.0 with TF backend (#373)
* remove pytest-lazy-fixture as dev dependency and skip test (with WG temp fix)

* change tensorflow dependency for cellfinder

* replace tensorflow.keras imports with plain keras imports

* add keras import and reorder

* add keras and TF 2.16 to pyproject.toml

* comment out TF version check for now

* change checkpoint filename for compliance with Keras 3; remove use_multiprocessing=False from fit() as it is no longer an input. test_train() passing

* add multiprocessing parameters to cube generator constructor and remove from fit() signature (Keras 3 change)

* apply temp garbage collector fix

* skip troublesome test

* skip running tests on CI on windows

* remove commented out TF check

* clean up commented-out code; explicitly pass use_multiprocessing=False (as before)

* remove str conversion before model.save

* raise test_detection error to keep SonarCloud happy

* skip running tests on windows on CI

* remove filename comment and small edits
sfmig authored Feb 8, 2024
1 parent 99cbda0 commit 29f8555
Showing 9 changed files with 72 additions and 65 deletions.
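At a glance, the migration swaps the TF-bundled Keras imports for the standalone keras package and moves multiprocessing options out of fit()/predict() and into the data generators. A minimal sketch of the import style after this commit (the to_categorical call is illustrative only):

```python
# Keras 2 style (TF-bundled), removed by this commit:
#   from tensorflow import keras
#   from tensorflow.keras.utils import Sequence

# Keras 3 style, used throughout after this commit:
import keras
from keras.utils import Sequence  # an alias of keras.utils.PyDataset in Keras 3

# keras.utils replaces tf.keras.utils, e.g. for one-hot labels:
labels = keras.utils.to_categorical([0, 1, 1], num_classes=2)
print(labels.shape)  # (3, 2)
```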
4 changes: 1 addition & 3 deletions .github/workflows/test_and_deploy.yml
```diff
@@ -42,12 +42,10 @@ jobs:
         # Run all supported Python versions on linux
         os: [ubuntu-latest]
         python-version: ["3.9", "3.10"]
-        # Include one windows, one macos run
+        # Include one macos run
         include:
           - os: macos-latest
             python-version: "3.10"
-          - os: windows-latest
-            python-version: "3.10"
 
     steps:
       # Cache the tensorflow model so we don't have to remake it every time
```
7 changes: 4 additions & 3 deletions cellfinder/core/classify/classify.py
```diff
@@ -1,10 +1,10 @@
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
+import keras
 import numpy as np
 from brainglobe_utils.cells.cells import Cell
 from brainglobe_utils.general.system import get_num_processes
-from tensorflow import keras
 
 from cellfinder.core import logger, types
 from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
@@ -63,6 +63,8 @@ def main(
         cube_width=cube_width,
         cube_height=cube_height,
         cube_depth=cube_depth,
+        use_multiprocessing=True,
+        workers=workers,
     )
 
     model = get_model(
@@ -73,10 +75,9 @@
     )
 
     logger.info("Running inference")
+    # in Keras 3.0 multiprocessing params are specified in the generator
     predictions = model.predict(
         inference_generator,
-        use_multiprocessing=True,
-        workers=workers,
         verbose=True,
         callbacks=callbacks,
     )
```
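The hunk above reflects a Keras 3 API change: model.predict() no longer accepts use_multiprocessing or workers, which now belong to the generator's constructor. A self-contained sketch of the call-site change, using a toy model and array rather than cellfinder's generator:

```python
import numpy as np
import keras

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(2)])
x = np.random.rand(8, 3).astype("float32")

# Keras 2 configured multiprocessing per call:
#   model.predict(inference_generator, use_multiprocessing=True, workers=4)
# Keras 3 rejects those arguments; they are passed to the generator's
# constructor (keras.utils.Sequence / PyDataset) instead, as in the
# CubeGeneratorFromFile call above.
predictions = model.predict(x, verbose=0)  # shape (8, 2)
```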
22 changes: 18 additions & 4 deletions cellfinder/core/classify/cube_generator.py
```diff
@@ -2,13 +2,13 @@
 from random import shuffle
 from typing import Dict, List, Optional, Tuple, Union
 
+import keras
 import numpy as np
-import tensorflow as tf
 from brainglobe_utils.cells.cells import Cell, group_cells_by_z
 from brainglobe_utils.general.numerical import is_even
+from keras.utils import Sequence
 from scipy.ndimage import zoom
 from skimage.io import imread
-from tensorflow.keras.utils import Sequence
 
 from cellfinder.core import types
 from cellfinder.core.classify.augment import AugmentationParameters, augment
@@ -56,7 +56,14 @@ def __init__(
         translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
         shuffle: bool = False,
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.points = points
         self.signal_array = signal_array
         self.background_array = background_array
@@ -220,7 +227,7 @@ def __getitem__(
 
         if self.train:
             batch_labels = [cell.type - 1 for cell in cell_batch]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
             return images, batch_labels
@@ -352,7 +359,14 @@ def __init__(
         translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
         train: bool = False,  # also return labels
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.im_shape = shape
         self.batch_size = batch_size
         self.labels = labels
@@ -414,7 +428,7 @@ def __getitem__(
 
         if self.train and self.labels is not None:
             batch_labels = [self.labels[k] for k in indexes]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
             return images, batch_labels
```
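Both generators now forward unrecognised constructor arguments to keras.utils.Sequence, whose Keras 3 base class (PyDataset) owns workers, use_multiprocessing and max_queue_size. A hypothetical, stripped-down generator showing the same pattern (CubeBatches and its shapes are illustrative, not cellfinder's API):

```python
import numpy as np
import keras


class CubeBatches(keras.utils.Sequence):
    """Hypothetical, stripped-down stand-in for the cube generators."""

    def __init__(self, images, labels, batch_size=4, *args, **kwargs):
        # Forward use_multiprocessing / workers / max_queue_size to the
        # Keras 3 Sequence (PyDataset) base class.
        super().__init__(*args, **kwargs)
        self.images = images
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, idx):
        lo, hi = idx * self.batch_size, (idx + 1) * self.batch_size
        # keras.utils.to_categorical replaces tf.keras.utils.to_categorical
        batch_labels = keras.utils.to_categorical(
            self.labels[lo:hi], num_classes=2
        )
        return self.images[lo:hi], batch_labels


# Multiprocessing options now travel through the constructor:
gen = CubeBatches(
    np.random.rand(16, 8, 8, 8, 2).astype("float32"),
    np.random.randint(0, 2, size=16),
    use_multiprocessing=False,
    workers=1,
)
```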
10 changes: 5 additions & 5 deletions cellfinder/core/classify/resnet.py
```diff
@@ -1,9 +1,8 @@
 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
 
-from tensorflow import Tensor
-from tensorflow.keras import Model
-from tensorflow.keras.initializers import Initializer
-from tensorflow.keras.layers import (
+from keras import Model
+from keras.initializers import Initializer
+from keras.layers import (
     Activation,
     Add,
     BatchNormalization,
@@ -14,7 +13,8 @@
     MaxPooling3D,
     ZeroPadding3D,
 )
-from tensorflow.keras.optimizers import Adam, Optimizer
+from keras.optimizers import Adam, Optimizer
+from tensorflow import Tensor
 
 #####################################################################
 # Define the types of ResNet
```
24 changes: 13 additions & 11 deletions cellfinder/core/classify/tools.py
```diff
@@ -1,9 +1,10 @@
 import os
-from typing import List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import List, Optional, Tuple, Union
 
+import keras
 import numpy as np
-import tensorflow as tf
-from tensorflow.keras import Model
+from keras import Model
 
 from cellfinder.core import logger
 from cellfinder.core.classify.resnet import build_model, layer_type
@@ -17,8 +18,7 @@ def get_model(
     inference: bool = False,
     continue_training: bool = False,
 ) -> Model:
-    """
-    Returns the correct model based on the arguments passed
+    """Returns the correct model based on the arguments passed
     :param existing_model: An existing, trained model. This is returned if it
         exists
     :param model_weights: This file is used to set the model weights if it
@@ -30,29 +30,31 @@
         by using the default one
     :param continue_training: If True, will ensure that a trained model
         exists. E.g. by using the default one
-    :return: A tf.keras model
+    :return: A keras model
     """
     if existing_model is not None or network_depth is None:
         logger.debug(f"Loading model: {existing_model}")
-        return tf.keras.models.load_model(existing_model)
+        return keras.models.load_model(existing_model)
     else:
         logger.debug(f"Creating a new instance of model: {network_depth}")
         model = build_model(
-            network_depth=network_depth, learning_rate=learning_rate
+            network_depth=network_depth,
+            learning_rate=learning_rate,
         )
         if inference or continue_training:
             logger.debug(
-                f"Setting model weights according to: {model_weights}"
+                f"Setting model weights according to: {model_weights}",
             )
             if model_weights is None:
-                raise IOError("`model_weights` must be provided")
+                raise OSError("`model_weights` must be provided")
             model.load_weights(model_weights)
        return model
 
 
 def make_lists(
-    tiff_files: Sequence, train: bool = True
+    tiff_files: Sequence,
+    train: bool = True,
 ) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]:
     signal_list = []
     background_list = []
```
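keras.models.load_model here replaces tf.keras.models.load_model; it reads the native .keras format (and legacy HDF5 files). A save/load round-trip sketch with a toy model and a hypothetical path:

```python
import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.save("toy_model.keras")  # native Keras 3 format
restored = keras.models.load_model("toy_model.keras")
assert len(restored.layers) == len(model.layers)
```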
28 changes: 18 additions & 10 deletions cellfinder/core/train/train_yml.py
```diff
@@ -324,7 +324,7 @@ def run(
 
     suppress_tf_logging(tf_suppress_log_messages)
 
-    from tensorflow.keras.callbacks import (
+    from keras.callbacks import (
         CSVLogger,
         ModelCheckpoint,
         TensorBoard,
@@ -386,15 +386,16 @@
             labels=labels_test,
             batch_size=batch_size,
             train=True,
+            use_multiprocessing=False,
         )
 
         # for saving checkpoints
-        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
+        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"
 
     else:
         logger.info("No validation data selected.")
         validation_generator = None
-        base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
+        base_checkpoint_file_name = "-epoch.{epoch:02d}"
 
     training_generator = CubeGeneratorFromDisk(
         signal_train,
@@ -404,6 +405,7 @@ def run(
         shuffle=True,
         train=True,
         augment=not no_augment,
+        use_multiprocessing=False,
     )
     callbacks = []
 
@@ -420,9 +422,14 @@
 
     if not no_save_checkpoints:
         if save_weights:
-            filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
+            filepath = str(
+                output_dir
+                / ("weight" + base_checkpoint_file_name + ".weights.h5")
+            )
         else:
-            filepath = str(output_dir / ("model" + base_checkpoint_file_name))
+            filepath = str(
+                output_dir / ("model" + base_checkpoint_file_name + ".keras")
+            )
 
         checkpoints = ModelCheckpoint(
             filepath,
@@ -431,25 +438,26 @@
         callbacks.append(checkpoints)
 
     if save_progress:
-        filepath = str(output_dir / "training.csv")
-        csv_logger = CSVLogger(filepath)
+        csv_filepath = str(output_dir / "training.csv")
+        csv_logger = CSVLogger(csv_filepath)
         callbacks.append(csv_logger)
 
     logger.info("Beginning training.")
+    # Keras 3.0: `use_multiprocessing` input is set in the
+    # `training_generator` (False by default)
     model.fit(
         training_generator,
         validation_data=validation_generator,
-        use_multiprocessing=False,
         epochs=epochs,
         callbacks=callbacks,
     )
 
     if save_weights:
         logger.info("Saving model weights")
-        model.save_weights(str(output_dir / "model_weights.h5"))
+        model.save_weights(output_dir / "model.weights.h5")
     else:
         logger.info("Saving model")
-        model.save(output_dir / "model.h5")
+        model.save(output_dir / "model.keras")
 
     logger.info(
         "Finished training, " "Total time taken: %s",
```
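The suffix changes above follow Keras 3's checkpoint rules: weight-only checkpoint files must end in .weights.h5, and whole-model checkpoints use the native .keras format. A sketch of the resulting callbacks (paths are illustrative):

```python
import keras

# Weight-only checkpoints must end in ".weights.h5" in Keras 3:
weights_checkpoint = keras.callbacks.ModelCheckpoint(
    "weight-epoch.{epoch:02d}.weights.h5",
    save_weights_only=True,
)

# Whole-model checkpoints use the native ".keras" format:
model_checkpoint = keras.callbacks.ModelCheckpoint(
    "model-epoch.{epoch:02d}.keras",
)

# The same rule applies to direct saves:
#   model.save_weights("model.weights.h5")
#   model.save("model.keras")
```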
21 changes: 0 additions & 21 deletions tests/core/conftest.py
```diff
@@ -1,4 +1,3 @@
-import os
 from typing import Tuple
 
 import numpy as np
@@ -9,26 +8,6 @@
 from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH
 
 
-@pytest.fixture(scope="session")
-def no_free_cpus() -> int:
-    """
-    Set number of free CPUs so all available CPUs are used by the tests.
-    """
-    return 0
-
-
-@pytest.fixture(scope="session")
-def run_on_one_cpu_only() -> int:
-    """
-    Set number of free CPUs so tests can use exactly one CPU.
-    """
-    cpus = os.cpu_count()
-    if cpus is not None:
-        return cpus - 1
-    else:
-        raise ValueError("No CPUs available.")
-
-
 @pytest.fixture(scope="session")
 def download_default_model():
     """
```
19 changes: 12 additions & 7 deletions tests/core/test_integration/test_detection.py
```diff
@@ -81,10 +81,13 @@ def test_detection_full(signal_array, background_array, free_cpus, request):
 
 
 def test_detection_small_planes(
-    signal_array, background_array, no_free_cpus, mocker
+    signal_array,
+    background_array,
+    mocker,
+    cpus_to_leave_free: int = 0,
 ):
     # Check that processing works when number of planes < number of processes
-    nproc = get_num_processes(no_free_cpus)
+    nproc = get_num_processes(cpus_to_leave_free)
     n_planes = 2
 
     # Don't want to bother classifying in this test, so mock classification
@@ -101,11 +104,13 @@
         background_array[0:n_planes],
         voxel_sizes,
         ball_z_size=5,
-        n_free_cpus=no_free_cpus,
+        n_free_cpus=cpus_to_leave_free,
     )
 
 
-def test_callbacks(signal_array, background_array, no_free_cpus):
+def test_callbacks(
+    signal_array, background_array, cpus_to_leave_free: int = 0
+):
     # 20 is minimum number of planes needed to find > 0 cells
     signal_array = signal_array[0:20]
     background_array = background_array[0:20]
@@ -130,7 +135,7 @@ def detect_finished_callback(points):
         detect_callback=detect_callback,
         classify_callback=classify_callback,
         detect_finished_callback=detect_finished_callback,
-        n_free_cpus=no_free_cpus,
+        n_free_cpus=cpus_to_leave_free,
     )
 
     np.testing.assert_equal(planes_done, np.arange(len(signal_array)))
@@ -148,13 +153,13 @@ def test_floating_point_error(signal_array, background_array):
     main(signal_array, background_array, voxel_sizes)
 
 
-def test_synthetic_data(synthetic_bright_spots, no_free_cpus):
+def test_synthetic_data(synthetic_bright_spots, cpus_to_leave_free: int = 0):
     signal_array, background_array = synthetic_bright_spots
     detected = main(
         signal_array,
         background_array,
         voxel_sizes,
-        n_free_cpus=no_free_cpus,
+        n_free_cpus=cpus_to_leave_free,
     )
     assert len(detected) == 8
 
```
2 changes: 1 addition & 1 deletion tests/core/test_integration/test_train.py
```diff
@@ -35,5 +35,5 @@ def test_train(tmpdir):
     sys.argv = train_args
     train_run()
 
-    model_file = os.path.join(tmpdir, "model.h5")
+    model_file = os.path.join(tmpdir, "model.keras")
     assert os.path.exists(model_file)
```
