From 83e35ca56ce4fd84859dbb577dc9ad12507e2d1f Mon Sep 17 00:00:00 2001 From: Stefan Hahmann Date: Tue, 23 Jul 2024 09:18:42 +0200 Subject: [PATCH] Change methods runClassification and classifyExternalProjects in ClassifyLineagesController --- .../ClassifyLineagesController.java | 73 +++++++++---------- .../multiproject/ExternalProjects.java | 9 ++- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/src/main/java/org/mastodon/mamut/classification/ClassifyLineagesController.java b/src/main/java/org/mastodon/mamut/classification/ClassifyLineagesController.java index b8cd99106..58ec4fb80 100644 --- a/src/main/java/org/mastodon/mamut/classification/ClassifyLineagesController.java +++ b/src/main/java/org/mastodon/mamut/classification/ClassifyLineagesController.java @@ -64,7 +64,6 @@ import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import java.util.StringJoiner; @@ -160,15 +159,16 @@ private String runClassification() String createdTagSetName; try { - Pair< List< List< BranchSpotTree > >, double[][] > rootsAndDistances = getRootsAndDistanceMatrix(); - List< List< BranchSpotTree > > roots = rootsAndDistances.getLeft(); + Pair< List< Pair< ProjectSession, List< BranchSpotTree > > >, double[][] > rootsAndDistances = getRootsAndDistanceMatrix(); + List< Pair< ProjectSession, List< BranchSpotTree > > > rootsMatrix = rootsAndDistances.getLeft(); double[][] distances = rootsAndDistances.getRight(); - Classification< BranchSpotTree > classification = classifyLineageTrees( roots.get( 0 ), distances ); + Pair< ProjectSession, List< BranchSpotTree > > referenceRoots = rootsMatrix.get( 0 ); + Classification< BranchSpotTree > classification = classifyLineageTrees( referenceRoots.getRight(), distances ); List< Pair< String, Integer > > tagsAndColors = createTagsAndColors( classification ); Function< BranchSpotTree, BranchSpot > branchSpotProvider = BranchSpotTree::getBranchSpot; createdTagSetName = applyClassification( classification, tagsAndColors, referenceModel, branchSpotProvider ); - if ( addTagSetToExternalProjects && roots.size() > 1 ) - classification = classifyUsingExternalProjects( roots, classification, distances, tagsAndColors ); + if ( addTagSetToExternalProjects && rootsMatrix.size() > 1 ) + classifyExternalProjects( rootsMatrix, distances, tagsAndColors ); if ( showDendrogram ) showDendrogram( classification ); } @@ -179,58 +179,57 @@ private String runClassification() return createdTagSetName; } - private Classification< BranchSpotTree > classifyUsingExternalProjects( final List< List< BranchSpotTree > > roots, - final Classification< BranchSpotTree > classification, final double[][] distances, - final List< Pair< String, Integer > > tagsAndColors ) + private void classifyExternalProjects( final List< Pair< ProjectSession, List< BranchSpotTree > > > rootsMatrix, + final double[][] distances, final List< Pair< String, Integer > > tagsAndColors ) { Function< BranchSpotTree, BranchSpot > branchSpotProvider; - Classification< BranchSpotTree > averageClassification = classification; - for ( int i = 1; i < roots.size(); i++ ) + for ( int i = 1; i < rootsMatrix.size(); i++ ) // NB: start at 1 to skip reference project { - averageClassification = classifyLineageTrees( roots.get( i ), distances ); - for ( Map.Entry< File, ProjectSession > externalProject : externalProjects.getProjects().entrySet() ) + Pair< ProjectSession, List< BranchSpotTree > > sessionAndRoots = rootsMatrix.get( i ); + Classification< BranchSpotTree > classification = classifyLineageTrees( sessionAndRoots.getRight(), distances ); + ProjectSession projectSession = sessionAndRoots.getLeft(); + ProjectModel projectModel = projectSession.getProjectModel(); + Model model = projectModel.getModel(); + File file = projectSession.getFile(); + branchSpotProvider = branchSpotTree -> model.getBranchGraph().vertices().stream() + .filter( ( branchSpot -> branchSpot.getFirstLabel().equals( branchSpotTree.getName() ) ) ) + .findFirst().orElse( null ); + applyClassification( classification, tagsAndColors, model, branchSpotProvider ); + try { - ProjectModel projectModel = externalProject.getValue().getProjectModel(); - File file = externalProject.getKey(); - branchSpotProvider = branchSpotTree -> projectModel.getModel().getBranchGraph().vertices().stream() - .filter( ( branchSpot -> branchSpot.getFirstLabel().equals( branchSpotTree.getName() ) ) ) - .findFirst().orElse( null ); - applyClassification( classification, tagsAndColors, projectModel.getModel(), branchSpotProvider ); - try - { - ProjectSaver.saveProject( file, projectModel ); - } - catch ( IOException e ) - { - logger.warn( "Could not save tag set of project {} to file {}. Message: {}", projectModel.getProjectName(), - file.getAbsolutePath(), e.getMessage() ); - } + ProjectSaver.saveProject( file, projectModel ); + } + catch ( IOException e ) + { + logger.warn( "Could not save tag set of project {} to file {}. Message: {}", projectModel.getProjectName(), + file.getAbsolutePath(), e.getMessage() ); } } - return averageClassification; } - private Pair< List< List< BranchSpotTree > >, double[][] > getRootsAndDistanceMatrix() + private Pair< List< Pair< ProjectSession, List< BranchSpotTree > > >, double[][] > getRootsAndDistanceMatrix() { List< BranchSpotTree > roots = getRoots(); if ( externalProjects.isEmpty() ) { double[][] distances = ClassificationUtils.getDistanceMatrix( roots, similarityMeasure ); - return Pair.of( Collections.singletonList( roots ), distances ); + return Pair.of( Collections.singletonList( Pair.of( null, roots ) ), distances ); } List< String > commonRootNames = findCommonRootNames(); - List< List< BranchSpotTree > > treeMatrix = new ArrayList<>(); + List< Pair< ProjectSession, List< BranchSpotTree > > > projectSessionsAndRootLists = new ArrayList<>(); keepCommonRootsAndSort( roots, commonRootNames ); - treeMatrix.add( roots ); - for ( ProjectModel projectModel : externalProjects.getProjectModels() ) + projectSessionsAndRootLists.add( Pair.of( null, roots ) ); + for ( ProjectSession projectSession : externalProjects.getProjectSessions() ) { - List< BranchSpotTree > externalRoots = getRoots( projectModel ); + List< BranchSpotTree > externalRoots = getRoots( projectSession.getProjectModel() ); keepCommonRootsAndSort( externalRoots, commonRootNames ); - treeMatrix.add( externalRoots ); + projectSessionsAndRootLists.add( Pair.of( projectSession, externalRoots ) ); } - return Pair.of( treeMatrix, ClassificationUtils.getAverageDistanceMatrix( treeMatrix, similarityMeasure ) ); + List< List< BranchSpotTree > > treeMatrix = + projectSessionsAndRootLists.stream().map( Pair::getRight ).collect( Collectors.toList() ); + return Pair.of( projectSessionsAndRootLists, ClassificationUtils.getAverageDistanceMatrix( treeMatrix, similarityMeasure ) ); } private List< String > findCommonRootNames() diff --git a/src/main/java/org/mastodon/mamut/classification/multiproject/ExternalProjects.java b/src/main/java/org/mastodon/mamut/classification/multiproject/ExternalProjects.java index d9098075e..7fc48be1d 100644 --- a/src/main/java/org/mastodon/mamut/classification/multiproject/ExternalProjects.java +++ b/src/main/java/org/mastodon/mamut/classification/multiproject/ExternalProjects.java @@ -10,6 +10,7 @@ import java.io.File; import java.io.IOException; import java.lang.invoke.MethodHandles; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -51,12 +52,12 @@ public Collection< ProjectModel > getProjectModels() } /** - * Gets a mapping of external project files to their project sessions - * @return the mapping + * Gets a list of {@link ProjectSession} + * @return the list */ - public Map< File, ProjectSession > getProjects() + public List< ProjectSession > getProjectSessions() { - return projectSessions; + return new ArrayList<>( projectSessions.values() ); } /**