Migrate to Keras 3.0 with TF backend #373

Merged: 18 commits, Feb 8, 2024
4 changes: 1 addition & 3 deletions .github/workflows/test_and_deploy.yml
@@ -42,12 +42,10 @@ jobs:
# Run all supported Python versions on linux
os: [ubuntu-latest]
python-version: ["3.9", "3.10"]
- # Include one windows, one macos run
+ # Include one macos run
include:
- os: macos-latest
python-version: "3.10"
- - os: windows-latest
-   python-version: "3.10"

steps:
# Cache the tensorflow model so we don't have to remake it every time
7 changes: 4 additions & 3 deletions cellfinder/core/classify/classify.py
@@ -1,10 +1,10 @@
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

+ import keras
import numpy as np
from brainglobe_utils.cells.cells import Cell
from brainglobe_utils.general.system import get_num_processes
- from tensorflow import keras

from cellfinder.core import logger, types
from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
@@ -63,6 +63,8 @@ def main(
cube_width=cube_width,
cube_height=cube_height,
cube_depth=cube_depth,
+ use_multiprocessing=True,
+ workers=workers,
)

model = get_model(
@@ -73,10 +75,9 @@
)

logger.info("Running inference")
+ # in Keras 3.0 multiprocessing params are specified in the generator
predictions = model.predict(
inference_generator,
- use_multiprocessing=True,
- workers=workers,
verbose=True,
callbacks=callbacks,
)
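For reference, a minimal sketch of the Keras 3 pattern this file now relies on: parallel-loading options move off `model.predict` and onto the data object itself, because `keras.utils.Sequence` (aliased to `PyDataset` in Keras 3) accepts `workers` and `use_multiprocessing` in its constructor. The generator and model below are toy stand-ins, not the cellfinder classes:

```python
import keras
import numpy as np


class ToyCubes(keras.utils.Sequence):
    """Hypothetical stand-in for CubeGeneratorFromFile."""

    def __init__(self, cubes, batch_size=32, **kwargs):
        # Keras 3: `workers` / `use_multiprocessing` are consumed here by
        # the Sequence base class, not by model.predict()
        super().__init__(**kwargs)
        self.cubes = cubes
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.cubes) / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        return self.cubes[start : start + self.batch_size]


model = keras.Sequential(
    [
        keras.Input(shape=(50, 50, 20, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
    ]
)
cubes = ToyCubes(
    np.zeros((64, 50, 50, 20, 2), dtype="float32"),
    use_multiprocessing=True,
    workers=4,
)
predictions = model.predict(cubes, verbose=True)  # no workers kwarg here
```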
22 changes: 18 additions & 4 deletions cellfinder/core/classify/cube_generator.py
@@ -2,13 +2,13 @@
from random import shuffle
from typing import Dict, List, Optional, Tuple, Union

+ import keras
import numpy as np
- import tensorflow as tf
from brainglobe_utils.cells.cells import Cell, group_cells_by_z
from brainglobe_utils.general.numerical import is_even
+ from keras.utils import Sequence
from scipy.ndimage import zoom
from skimage.io import imread
- from tensorflow.keras.utils import Sequence

from cellfinder.core import types
from cellfinder.core.classify.augment import AugmentationParameters, augment
@@ -56,7 +56,14 @@ def __init__(
translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
shuffle: bool = False,
interpolation_order: int = 2,
+ *args,
+ **kwargs,
):
+ # pass any additional arguments not specified in signature to the
+ # constructor of the superclass (e.g.: `use_multiprocessing` or
+ # `workers`)
+ super().__init__(*args, **kwargs)
+
self.points = points
self.signal_array = signal_array
self.background_array = background_array
@@ -220,7 +227,7 @@ def __getitem__(

if self.train:
batch_labels = [cell.type - 1 for cell in cell_batch]
- batch_labels = tf.keras.utils.to_categorical(
+ batch_labels = keras.utils.to_categorical(
batch_labels, num_classes=self.classes
)
return images, batch_labels
@@ -352,7 +359,14 @@ def __init__(
translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
train: bool = False, # also return labels
interpolation_order: int = 2,
+ *args,
+ **kwargs,
):
+ # pass any additional arguments not specified in signature to the
+ # constructor of the superclass (e.g.: `use_multiprocessing` or
+ # `workers`)
+ super().__init__(*args, **kwargs)
+
self.im_shape = shape
self.batch_size = batch_size
self.labels = labels
@@ -414,7 +428,7 @@ def __getitem__(

if self.train and self.labels is not None:
batch_labels = [self.labels[k] for k in indexes]
- batch_labels = tf.keras.utils.to_categorical(
+ batch_labels = keras.utils.to_categorical(
batch_labels, num_classes=self.classes
)
return images, batch_labels
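The `*args, **kwargs` forwarding above is the load-bearing part of this change: in Keras 3, `keras.utils.Sequence` takes `workers`, `use_multiprocessing`, and (as I understand the API) `max_queue_size` in its own `__init__`, so any subclass that defines its own constructor must pass the unknown keyword arguments through or they are lost before the base class sees them. Distilled to a sketch with illustrative names:

```python
import keras
import numpy as np


class MyGenerator(keras.utils.Sequence):
    def __init__(self, data, *args, **kwargs):
        # Forward anything not in this signature (e.g. `workers` or
        # `use_multiprocessing`) to the Keras 3 base class, which owns
        # the multiprocessing configuration.
        super().__init__(*args, **kwargs)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


gen = MyGenerator(np.arange(8).reshape(4, 2), workers=2, use_multiprocessing=True)
```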
10 changes: 5 additions & 5 deletions cellfinder/core/classify/resnet.py
@@ -1,9 +1,8 @@
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

- from tensorflow import Tensor
- from tensorflow.keras import Model
- from tensorflow.keras.initializers import Initializer
- from tensorflow.keras.layers import (
+ from keras import Model
+ from keras.initializers import Initializer
+ from keras.layers import (
Activation,
Add,
BatchNormalization,
@@ -14,7 +13,8 @@
MaxPooling3D,
ZeroPadding3D,
)
- from tensorflow.keras.optimizers import Adam, Optimizer
+ from keras.optimizers import Adam, Optimizer
+ from tensorflow import Tensor

#####################################################################
# Define the types of ResNet
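These import moves are mechanical: with Keras 3 on the TF backend, everything previously under `tensorflow.keras.*` is imported from the standalone `keras` package, and only backend types such as `Tensor` still come from `tensorflow`. A quick sanity sketch of the new-style imports (a toy residual block, not the cellfinder ResNet):

```python
from keras import Model
from keras.layers import Activation, Add, BatchNormalization, Conv3D, Input
from keras.optimizers import Adam

inputs = Input(shape=(50, 50, 20, 2))
x = Conv3D(8, 3, padding="same")(inputs)
x = BatchNormalization()(x)
x = Activation("relu")(x)
x = Add()([x, Conv3D(8, 1)(inputs)])  # toy shortcut branch
model = Model(inputs, x)
model.compile(optimizer=Adam(learning_rate=1e-3), loss="mse")
```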
24 changes: 13 additions & 11 deletions cellfinder/core/classify/tools.py
@@ -1,9 +1,10 @@
import os
- from typing import List, Optional, Sequence, Tuple, Union
+ from collections.abc import Sequence
+ from typing import List, Optional, Tuple, Union

+ import keras
import numpy as np
- import tensorflow as tf
- from tensorflow.keras import Model
+ from keras import Model

from cellfinder.core import logger
from cellfinder.core.classify.resnet import build_model, layer_type
@@ -17,8 +18,7 @@ def get_model(
inference: bool = False,
continue_training: bool = False,
) -> Model:
"""
Returns the correct model based on the arguments passed
"""Returns the correct model based on the arguments passed
:param existing_model: An existing, trained model. This is returned if it
exists
:param model_weights: This file is used to set the model weights if it
@@ -30,29 +30,31 @@
by using the default one
:param continue_training: If True, will ensure that a trained model
exists. E.g. by using the default one
- :return: A tf.keras model
+ :return: A keras model

"""
if existing_model is not None or network_depth is None:
logger.debug(f"Loading model: {existing_model}")
- return tf.keras.models.load_model(existing_model)
+ return keras.models.load_model(existing_model)
else:
logger.debug(f"Creating a new instance of model: {network_depth}")
model = build_model(
- network_depth=network_depth, learning_rate=learning_rate
+ network_depth=network_depth,
+ learning_rate=learning_rate,
)
if inference or continue_training:
logger.debug(
f"Setting model weights according to: {model_weights}"
f"Setting model weights according to: {model_weights}",
)
if model_weights is None:
- raise IOError("`model_weights` must be provided")
+ raise OSError("`model_weights` must be provided")
model.load_weights(model_weights)
return model


def make_lists(
- tiff_files: Sequence, train: bool = True
+ tiff_files: Sequence,
+ train: bool = True,
) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]:
signal_list = []
background_list = []
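Two small notes on this file: `IOError` has been an alias of `OSError` since Python 3.3, so that rename is purely cosmetic, and `keras.models.load_model` now expects the Keras 3 native `.keras` format, with weights-only files required to end in `.weights.h5`. A save/load round-trip sketch (toy model, illustrative filenames):

```python
import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2)])
model.save("toy.keras")  # native format; Keras 3 enforces the extension
reloaded = keras.models.load_model("toy.keras")

model.save_weights("toy.weights.h5")  # weights-only needs .weights.h5
reloaded.load_weights("toy.weights.h5")
```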
28 changes: 18 additions & 10 deletions cellfinder/core/train/train_yml.py
@@ -324,7 +324,7 @@ def run(

suppress_tf_logging(tf_suppress_log_messages)

- from tensorflow.keras.callbacks import (
+ from keras.callbacks import (
CSVLogger,
ModelCheckpoint,
TensorBoard,
Expand Down Expand Up @@ -386,15 +386,16 @@ def run(
labels=labels_test,
batch_size=batch_size,
train=True,
+ use_multiprocessing=False,
)

# for saving checkpoints
base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"

else:
logger.info("No validation data selected.")
validation_generator = None
base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
base_checkpoint_file_name = "-epoch.{epoch:02d}"

training_generator = CubeGeneratorFromDisk(
signal_train,
Expand All @@ -404,6 +405,7 @@ def run(
shuffle=True,
train=True,
augment=not no_augment,
+ use_multiprocessing=False,
)
callbacks = []

Expand All @@ -420,9 +422,14 @@ def run(

if not no_save_checkpoints:
if save_weights:
filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
filepath = str(
output_dir
/ ("weight" + base_checkpoint_file_name + ".weights.h5")
)
else:
filepath = str(output_dir / ("model" + base_checkpoint_file_name))
filepath = str(
output_dir / ("model" + base_checkpoint_file_name + ".keras")
)

checkpoints = ModelCheckpoint(
filepath,
Expand All @@ -431,25 +438,26 @@ def run(
callbacks.append(checkpoints)

if save_progress:
filepath = str(output_dir / "training.csv")
csv_logger = CSVLogger(filepath)
csv_filepath = str(output_dir / "training.csv")
csv_logger = CSVLogger(csv_filepath)
callbacks.append(csv_logger)

logger.info("Beginning training.")
+ # Keras 3.0: `use_multiprocessing` input is set in the
+ # `training_generator` (False by default)
model.fit(
training_generator,
validation_data=validation_generator,
- use_multiprocessing=False,
epochs=epochs,
callbacks=callbacks,
)

if save_weights:
logger.info("Saving model weights")
- model.save_weights(str(output_dir / "model_weights.h5"))
+ model.save_weights(output_dir / "model.weights.h5")
else:
logger.info("Saving model")
- model.save(output_dir / "model.h5")
+ model.save(output_dir / "model.keras")

logger.info(
"Finished training, " "Total time taken: %s",
21 changes: 0 additions & 21 deletions tests/core/conftest.py
@@ -1,4 +1,3 @@
- import os
from typing import Tuple

import numpy as np
@@ -9,26 +8,6 @@
from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH


- @pytest.fixture(scope="session")
- def no_free_cpus() -> int:
- """
- Set number of free CPUs so all available CPUs are used by the tests.
- """
- return 0
-
-
- @pytest.fixture(scope="session")
- def run_on_one_cpu_only() -> int:
- """
- Set number of free CPUs so tests can use exactly one CPU.
- """
- cpus = os.cpu_count()
- if cpus is not None:
- return cpus - 1
- else:
- raise ValueError("No CPUs available.")
-
-
@pytest.fixture(scope="session")
def download_default_model():
"""
19 changes: 12 additions & 7 deletions tests/core/test_integration/test_detection.py
@@ -81,10 +81,13 @@ def test_detection_full(signal_array, background_array, free_cpus, request):


def test_detection_small_planes(
- signal_array, background_array, no_free_cpus, mocker
+ signal_array,
+ background_array,
+ mocker,
+ cpus_to_leave_free: int = 0,
):
# Check that processing works when number of planes < number of processes
- nproc = get_num_processes(no_free_cpus)
+ nproc = get_num_processes(cpus_to_leave_free)
n_planes = 2

# Don't want to bother classifying in this test, so mock classifcation
@@ -101,11 +104,13 @@
background_array[0:n_planes],
voxel_sizes,
ball_z_size=5,
- n_free_cpus=no_free_cpus,
+ n_free_cpus=cpus_to_leave_free,
)


- def test_callbacks(signal_array, background_array, no_free_cpus):
+ def test_callbacks(
+     signal_array, background_array, cpus_to_leave_free: int = 0
+ ):
# 20 is minimum number of planes needed to find > 0 cells
signal_array = signal_array[0:20]
background_array = background_array[0:20]
Expand All @@ -130,7 +135,7 @@ def detect_finished_callback(points):
detect_callback=detect_callback,
classify_callback=classify_callback,
detect_finished_callback=detect_finished_callback,
- n_free_cpus=no_free_cpus,
+ n_free_cpus=cpus_to_leave_free,
)

np.testing.assert_equal(planes_done, np.arange(len(signal_array)))
@@ -148,13 +153,13 @@ def test_floating_point_error(signal_array, background_array):
main(signal_array, background_array, voxel_sizes)


- def test_synthetic_data(synthetic_bright_spots, no_free_cpus):
+ def test_synthetic_data(synthetic_bright_spots, cpus_to_leave_free: int = 0):
signal_array, background_array = synthetic_bright_spots
detected = main(
signal_array,
background_array,
voxel_sizes,
- n_free_cpus=no_free_cpus,
+ n_free_cpus=cpus_to_leave_free,
)
assert len(detected) == 8

2 changes: 1 addition & 1 deletion tests/core/test_integration/test_train.py
@@ -35,5 +35,5 @@ def test_train(tmpdir):
sys.argv = train_args
train_run()

model_file = os.path.join(tmpdir, "model.h5")
model_file = os.path.join(tmpdir, "model.keras")
assert os.path.exists(model_file)