Skip to content

Commit

Permalink
Prep merge (#349)
Browse files Browse the repository at this point in the history
* Remove unnecessary code duplication

* Remove redundant parameter

* Make brainmapper check optional in this repo for releases and status checks on PRs

---------

Co-authored-by: willGraham01 <[email protected]>
  • Loading branch information
adamltyson and willGraham01 authored Jan 5, 2024
1 parent d94a73b commit c38399a
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ jobs:
build_sdist_wheel:
name: Build source distribution and wheel
needs: [test, test_brainmapper_cli, test_numba_disabled]
needs: [test, test_numba_disabled]
if: github.event_name == 'push' && github.ref_type == 'tag'
runs-on: ubuntu-latest
steps:
Expand Down
16 changes: 2 additions & 14 deletions benchmarks/benchmarks/tools/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from brainglobe_utils.general.system import get_num_processes

from cellfinder.core.tools.prep import (
prep_classification,
prep_model_weights,
prep_models,
prep_tensorflow,
prep_training,
)


Expand Down Expand Up @@ -43,30 +42,19 @@ def teardown(self, model_name):

def time_prep_models(self, model_name):
prep_models(
self.trained_model,
self.model_weights,
self.install_path,
model_name,
)

def time_prep_classification(self, model_name):
prep_classification(
self.trained_model,
prep_model_weights(
self.model_weights,
self.install_path,
model_name,
self.n_free_cpus,
)

def time_prep_training(self, model_name):
prep_training(
self.n_free_cpus,
self.trained_model,
self.model_weights,
self.install_path,
model_name,
)


class PrepTF:
def setup(self):
Expand Down
4 changes: 2 additions & 2 deletions cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def main(
detect_finished_callback(points)

install_path = None
model_weights = prep.prep_classification(
trained_model, model_weights, install_path, model, n_free_cpus
model_weights = prep.prep_model_weights(
model_weights, install_path, model, n_free_cpus
)
if len(points) > 0:
logger.info("Running classification")
Expand Down
23 changes: 2 additions & 21 deletions cellfinder/core/tools/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,25 @@
DEFAULT_INSTALL_PATH = home / ".cellfinder"


def prep_classification(
trained_model: Optional[os.PathLike],
def prep_model_weights(
model_weights: Optional[os.PathLike],
install_path: Optional[os.PathLike],
model_name: model_download.model_type,
n_free_cpus: int,
) -> Path:
n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
prep_tensorflow(n_processes)
model_weights = prep_models(
trained_model, model_weights, install_path, model_name
)
model_weights = prep_models(model_weights, install_path, model_name)

return model_weights


def prep_training(
n_free_cpus: int,
trained_model: Optional[os.PathLike],
model_weights: Optional[os.PathLike],
install_path: Optional[os.PathLike],
model_name: model_download.model_type,
) -> Path:
n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
prep_tensorflow(n_processes)
model_weights = prep_models(
trained_model, model_weights, install_path, model_name
)
return model_weights


def prep_tensorflow(max_threads: int) -> None:
tf_tools.set_tf_threads(max_threads)
tf_tools.allow_gpu_memory_growth()


def prep_models(
trained_model_path: Optional[os.PathLike],
model_weights_path: Optional[os.PathLike],
install_path: Optional[os.PathLike],
model_name: model_download.model_type,
Expand Down
10 changes: 3 additions & 7 deletions cellfinder/core/train/train_yml.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,17 +332,13 @@ def run(

from cellfinder.core.classify.cube_generator import CubeGeneratorFromDisk
from cellfinder.core.classify.tools import get_model, make_lists
from cellfinder.core.tools.prep import prep_training
from cellfinder.core.tools.prep import prep_model_weights

start_time = datetime.now()

ensure_directory_exists(output_dir)
model_weights = prep_training(
n_free_cpus,
trained_model,
install_path,
model_weights,
model,
model_weights = prep_model_weights(
install_path, model_weights, model, n_free_cpus
)

yaml_contents = parse_yaml(yaml_file)
Expand Down

0 comments on commit c38399a

Please sign in to comment.