diff --git a/doc/analysis.rst b/doc/analysis.rst
index bf238c22..888320c0 100644
--- a/doc/analysis.rst
+++ b/doc/analysis.rst
@@ -9,3 +9,20 @@ to get a quick overview over all models, as a pandas dataframe.
Additionally, see the Python API docs for the :mod:`petab_select.analyze` module, which contains some methods to subset and group models,
or compute "weights" (e.g. Akaike weights).
+
+Model hashes
+^^^^^^^^^^^^
+
+Model hashes are special objects in the library that are generated from model-specific information that is unique within a single PEtab Select problem.
+
+This means you can reconstruct a model given its model hash. For example, given a model hash such as `M1-000`, you can reconstruct the :class:`petab_select.ModelHash` from a string, then reconstruct the :class:`petab_select.Model`.
+
+.. code-block:: python
+
+ ModelHash.from_hash("M1-000").get_model(petab_select_problem)
+
+You can use this to get the uncalibrated version of a calibrated model.
+
+.. code-block:: python
+
+ model.hash.get_model(petab_select_problem)
diff --git a/doc/examples/calibrated_models/calibrated_models.yaml b/doc/examples/calibrated_models/calibrated_models.yaml
index 53abe836..73086f00 100644
--- a/doc/examples/calibrated_models/calibrated_models.yaml
+++ b/doc/examples/calibrated_models/calibrated_models.yaml
@@ -4,16 +4,16 @@
estimated_parameters:
sigma_x2: 4.462298422134608
iteration: 1
- model_hash: M_0-000
- model_id: M_0-000
- model_subspace_id: M_0
+ model_hash: M-000
+ model_id: M-000
+ model_subspace_id: M
model_subspace_indices:
- 0
- 0
- 0
parameters:
- k1: 0
- k2: 0
+ k1: 0.2
+ k2: 0.1
k3: 0
model_subspace_petab_yaml: petab_problem.yaml
predecessor_model_hash: virtual_initial_model-
@@ -24,19 +24,19 @@
k3: 0.0
sigma_x2: 0.12242920113658338
iteration: 2
- model_hash: M_1-000
- model_id: M_1-000
- model_subspace_id: M_1
+ model_hash: M-001
+ model_id: M-001
+ model_subspace_id: M
model_subspace_indices:
- 0
- 0
- - 0
+ - 1
parameters:
k1: 0.2
k2: 0.1
k3: estimate
model_subspace_petab_yaml: petab_problem.yaml
- predecessor_model_hash: M_0-000
+ predecessor_model_hash: M-000
- criteria:
AICc: -0.27451438069575573
NLLH: -4.137257190347878
@@ -44,19 +44,19 @@
k2: 0.10147824307890803
sigma_x2: 0.12142219599557078
iteration: 2
- model_hash: M_2-000
- model_id: M_2-000
- model_subspace_id: M_2
+ model_hash: M-010
+ model_id: M-010
+ model_subspace_id: M
model_subspace_indices:
- 0
- - 0
+ - 1
- 0
parameters:
k1: 0.2
k2: estimate
k3: 0
model_subspace_petab_yaml: petab_problem.yaml
- predecessor_model_hash: M_0-000
+ predecessor_model_hash: M-000
- criteria:
AICc: -0.7053270766271886
NLLH: -4.352663538313594
@@ -64,11 +64,11 @@
k1: 0.20160925279667963
sigma_x2: 0.11714017664827497
iteration: 2
- model_hash: M_3-000
- model_id: M_3-000
- model_subspace_id: M_3
+ model_hash: M-100
+ model_id: M-100
+ model_subspace_id: M
model_subspace_indices:
- - 0
+ - 1
- 0
- 0
parameters:
@@ -76,7 +76,7 @@
k2: 0.1
k3: 0
model_subspace_petab_yaml: petab_problem.yaml
- predecessor_model_hash: M_0-000
+ predecessor_model_hash: M-000
- criteria:
AICc: 9.294672923372811
NLLH: -4.352663538313594
@@ -85,19 +85,19 @@
k3: 0.0
sigma_x2: 0.11714017664827497
iteration: 3
- model_hash: M_5-000
- model_id: M_5-000
- model_subspace_id: M_5
+ model_hash: M-101
+ model_id: M-101
+ model_subspace_id: M
model_subspace_indices:
+ - 1
- 0
- - 0
- - 0
+ - 1
parameters:
k1: estimate
k2: 0.1
k3: estimate
model_subspace_petab_yaml: petab_problem.yaml
- predecessor_model_hash: M_3-000
+ predecessor_model_hash: M-100
- criteria:
AICc: 7.8521704398854
NLLH: -5.0739147800573
@@ -106,19 +106,19 @@
k2: 0.0859052351446815
sigma_x2: 0.10386846319370771
iteration: 3
- model_hash: M_6-000
- model_id: M_6-000
- model_subspace_id: M_6
+ model_hash: M-110
+ model_id: M-110
+ model_subspace_id: M
model_subspace_indices:
- - 0
- - 0
+ - 1
+ - 1
- 0
parameters:
k1: estimate
k2: estimate
k3: 0
model_subspace_petab_yaml: petab_problem.yaml
- predecessor_model_hash: M_3-000
+ predecessor_model_hash: M-100
- criteria:
AICc: 35.94352968170024
NLLH: -6.028235159149878
@@ -128,16 +128,16 @@
k3: 0.0010850434974038557
sigma_x2: 0.08859278245811462
iteration: 4
- model_hash: M_7-000
- model_id: M_7-000
- model_subspace_id: M_7
+ model_hash: M-111
+ model_id: M-111
+ model_subspace_id: M
model_subspace_indices:
- - 0
- - 0
- - 0
+ - 1
+ - 1
+ - 1
parameters:
k1: estimate
k2: estimate
k3: estimate
model_subspace_petab_yaml: petab_problem.yaml
- predecessor_model_hash: M_3-000
+ predecessor_model_hash: M-100
diff --git a/doc/examples/calibrated_models/model_space.tsv b/doc/examples/calibrated_models/model_space.tsv
index 87058e5d..c75f6c4a 100644
--- a/doc/examples/calibrated_models/model_space.tsv
+++ b/doc/examples/calibrated_models/model_space.tsv
@@ -1,9 +1,2 @@
model_subspace_id petab_yaml k1 k2 k3
-M_0 petab_problem.yaml 0 0 0
-M_1 petab_problem.yaml 0.2 0.1 estimate
-M_2 petab_problem.yaml 0.2 estimate 0
-M_3 petab_problem.yaml estimate 0.1 0
-M_4 petab_problem.yaml 0.2 estimate estimate
-M_5 petab_problem.yaml estimate 0.1 estimate
-M_6 petab_problem.yaml estimate estimate 0
-M_7 petab_problem.yaml estimate estimate estimate
+M petab_problem.yaml 0.2;estimate 0.1;estimate 0
diff --git a/doc/examples/visualization.ipynb b/doc/examples/visualization.ipynb
index 16bd485e..5ce2b7fd 100644
--- a/doc/examples/visualization.ipynb
+++ b/doc/examples/visualization.ipynb
@@ -31,7 +31,6 @@
"\n",
"import petab_select\n",
"import petab_select.plot\n",
- "from petab_select import VIRTUAL_INITIAL_MODEL\n",
"\n",
"models = petab_select.Models.from_yaml(\n",
" \"calibrated_models/calibrated_models.yaml\"\n",
@@ -45,7 +44,9 @@
"metadata": {},
"outputs": [],
"source": [
- "models.df.style.background_gradient(\n",
+ "models.df.drop(\n",
+ " columns=[petab_select.Criterion.AIC, petab_select.Criterion.BIC]\n",
+ ").style.background_gradient(\n",
" cmap=matplotlib.colormaps.get_cmap(\"summer\"),\n",
" subset=[petab_select.Criterion.AICC],\n",
")"
@@ -56,7 +57,7 @@
"id": "aaeb0606",
"metadata": {},
"source": [
- "We customize the labels here to just be the second part of a model hash: a binary string that describes their estimated parameters. We additionally set a custom color for a couple of models, and the default color for the other models. We also set a nicer label for the \"virtual initial model\", which is a hypothetical model that PEtab Select uses by default to initialize a forward search (with no parameters)."
+    "To use the plotting methods, we need to first set up an object that contains information common to multiple plotting methods. This can include the models, custom colors and labels, and the criterion."
]
},
{
@@ -66,25 +67,21 @@
"metadata": {},
"outputs": [],
"source": [
- "# Custom labels\n",
- "labels = {}\n",
- "for model in models:\n",
- " labels[model.hash] = \"M_\" + \"\".join(\n",
- " \"1\" if value == petab_select.ESTIMATE else \"0\"\n",
- " for value in model.parameters.values()\n",
- " )\n",
- "labels[VIRTUAL_INITIAL_MODEL.hash] = \"\\n\".join(\n",
- " petab_select.VIRTUAL_INITIAL_MODEL.model_subspace_id.split(\"_\")\n",
- ").title()\n",
- "\n",
"# Custom colors for some models\n",
"colors = {\n",
- " \"M_000\": \"lightgreen\",\n",
- " \"M_001\": \"lightgreen\",\n",
+ " \"M-000\": \"lightgreen\",\n",
+ " \"M-001\": \"lightgreen\",\n",
"}\n",
"\n",
+ "plot_data = petab_select.plot.PlotData(\n",
+ " models=models,\n",
+ " criterion=petab_select.Criterion.AICC,\n",
+ " colors=colors,\n",
+ " relative_criterion=True,\n",
+ ")\n",
+ "\n",
"# Change default color\n",
- "petab_select.plot.NORMAL_NODE_COLOR = \"darkgray\""
+ "petab_select.plot.DEFAULT_NODE_COLOR = \"darkgray\""
]
},
{
@@ -106,7 +103,7 @@
"metadata": {},
"outputs": [],
"source": [
- "petab_select.plot.upset(models=models, criterion=petab_select.Criterion.AICC);"
+ "petab_select.plot.upset(plot_data=plot_data);"
]
},
{
@@ -128,11 +125,7 @@
"metadata": {},
"outputs": [],
"source": [
- "petab_select.plot.line_best_by_iteration(\n",
- " models=models,\n",
- " criterion=petab_select.Criterion.AICC,\n",
- " labels=labels,\n",
- ");"
+ "petab_select.plot.line_best_by_iteration(plot_data=plot_data);"
]
},
{
@@ -152,12 +145,11 @@
"metadata": {},
"outputs": [],
"source": [
- "petab_select.plot.graph_history(\n",
- " models=models,\n",
- " criterion=petab_select.Criterion.AICC,\n",
- " labels=labels,\n",
- " colors=colors,\n",
- ");"
+ "# Add the relative criterion value to each label\n",
+ "plot_data.augment_labels(criterion=True)\n",
+ "petab_select.plot.graph_history(plot_data=plot_data)\n",
+ "# Reset the labels (remove the relative criterion)\n",
+ "plot_data.augment_labels()"
]
},
{
@@ -177,12 +169,7 @@
"metadata": {},
"outputs": [],
"source": [
- "petab_select.plot.bar_criterion_vs_models(\n",
- " models=models,\n",
- " criterion=petab_select.Criterion.AICC,\n",
- " labels=labels,\n",
- " colors=colors,\n",
- ");"
+ "petab_select.plot.bar_criterion_vs_models(plot_data=plot_data);"
]
},
{
@@ -207,10 +194,7 @@
"outputs": [],
"source": [
"petab_select.plot.scatter_criterion_vs_n_estimated(\n",
- " models=models,\n",
- " criterion=petab_select.Criterion.AICC,\n",
- " labels=labels,\n",
- " colors=colors,\n",
+ " plot_data=plot_data,\n",
" # Uncomment to turn off jitter.\n",
" # max_jitter=0,\n",
");"
@@ -231,7 +215,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5ce191fc",
+ "id": "21157e4d-b2ba-4cb1-95f6-e14052c86959",
"metadata": {},
"outputs": [],
"source": [
@@ -244,10 +228,11 @@
"# cmap = matplotlib.colors.LinearSegmentedColormap.from_list(\"\", [\"green\",\"lightgreen\"])\n",
"# colorbar_mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)\n",
"\n",
+ "# Augment labels with the changes in parameters of each model, compared to their predecessor model\n",
+ "plot_data.augment_labels(added_parameters=True, removed_parameters=True)\n",
+ "\n",
"petab_select.plot.graph_iteration_layers(\n",
- " models=models,\n",
- " criterion=petab_select.Criterion.AICC,\n",
- " labels=labels,\n",
+ " plot_data=plot_data,\n",
" draw_networkx_kwargs={\n",
" \"arrowstyle\": \"-|>\",\n",
" \"node_shape\": \"s\",\n",
diff --git a/doc/logo/editable/README.md b/doc/logo/editable/README.md
new file mode 100644
index 00000000..0e38ed69
--- /dev/null
+++ b/doc/logo/editable/README.md
@@ -0,0 +1 @@
+The text in these copies is still text. In the copies in the parent directory, the text is converted to paths (e.g. `Path -> Object to Path` in Inkscape), for consistent rendering on systems that don't have the fonts installed.
diff --git a/doc/logo/editable/logo-tall.svg b/doc/logo/editable/logo-tall.svg
new file mode 100644
index 00000000..e786c0a1
--- /dev/null
+++ b/doc/logo/editable/logo-tall.svg
@@ -0,0 +1,191 @@
+
+
+
+
diff --git a/doc/logo/editable/logo-wide.svg b/doc/logo/editable/logo-wide.svg
new file mode 100644
index 00000000..0c51f8a2
--- /dev/null
+++ b/doc/logo/editable/logo-wide.svg
@@ -0,0 +1,194 @@
+
+
+
+
diff --git a/doc/logo/logo-tall.svg b/doc/logo/logo-tall.svg
index e786c0a1..dc5d7389 100644
--- a/doc/logo/logo-tall.svg
+++ b/doc/logo/logo-tall.svg
@@ -8,8 +8,8 @@
version="1.1"
id="svg5"
xml:space="preserve"
- inkscape:version="1.3.2 (1:1.3.2+202311252150+091e20ef0f)"
- sodipodi:docname="008.svg"
+ inkscape:version="1.4 (1:1.4+202410161351+e7c3feb100)"
+ sodipodi:docname="logo-tall.svg"
inkscape:export-filename="001.pdf"
inkscape:export-xdpi="299"
inkscape:export-ydpi="299"
@@ -28,11 +28,11 @@
inkscape:document-units="mm"
showgrid="false"
showborder="true"
- inkscape:zoom="0.44697994"
- inkscape:cx="-102.9129"
- inkscape:cy="-275.18014"
- inkscape:window-width="2490"
- inkscape:window-height="1376"
+ inkscape:zoom="1.2642502"
+ inkscape:cx="-156.21908"
+ inkscape:cy="182.32151"
+ inkscape:window-width="2560"
+ inkscape:window-height="1403"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-maximized="1"
@@ -49,7 +49,7 @@
style="color-interpolation-filters:sRGB"
inkscape:label="Drop Shadow"
id="filter6"
- x="-0.0017246602"
+ x="-0.0017246601"
y="-0.0040821556"
width="1.0045991"
height="1.0108857">petabSELECTpetabSELECT nx.DiGraph:
+ """Get a graph representation of the models in terms of their ancestry.
+
+ Edges connect models with their predecessor models.
+
+ Args:
+ models:
+ The models.
+ labels:
+ Alternative labels for the models. Keys are model hashes, values
+ are the labels.
+
+ Returns:
+ The graph.
+ """
+ if labels is None:
+ labels = {}
+
+ G = nx.DiGraph()
+ edges = []
+ for model in models:
+ tail = labels.get(
+ model.predecessor_model_hash, model.predecessor_model_hash
+ )
+ head = labels.get(model.hash, model.hash)
+ edges.append((tail, head))
+ G.add_edges_from(edges)
+ return G
+
+
+def get_parameter_changes(
+ models: Models,
+ as_dict: bool = False,
+) -> (
+ dict[ModelHash, list[tuple[set[str], set[str]]]]
+ | list[tuple[set[str], set[str]]]
+):
+    """Get the differences in parameters between models and their predecessors.
+
+ Args:
+ models:
+ The models.
+ as_dict:
+ Whether to return a dictionary, with model hashes for keys.
+
+ Returns:
+ The parameter changes. Each model has a 2-tuple of sets of parameters.
+ The first and second sets are the added and removed parameters,
+ respectively. If the predecessor model is undefined (e.g. the
+ ``VIRTUAL_INITIAL_MODEL``), then both sets will be empty.
+ """
+ estimated_parameters = {
+ model.hash: set(model.estimated_parameters) for model in models
+ }
+ parameter_changes = [
+ (set(), set())
+ if model.predecessor_model_hash not in estimated_parameters
+ else (
+ estimated_parameters[model.hash].difference(
+ estimated_parameters[model.predecessor_model_hash]
+ ),
+ estimated_parameters[model.predecessor_model_hash].difference(
+ estimated_parameters[model.hash]
+ ),
+ )
+ for model in models
+ ]
+
+ if as_dict:
+ return dict(zip(models.hashes, parameter_changes, strict=True))
+ return parameter_changes
diff --git a/petab_select/model.py b/petab_select/model.py
index cd1bdb94..f0383fd5 100644
--- a/petab_select/model.py
+++ b/petab_select/model.py
@@ -303,6 +303,8 @@ class ModelBase(VirtualModelBase):
"""The iteration of model selection that calibrated this model."""
model_id: str = Field(default=None)
"""The model ID."""
+ model_label: str | None = Field(default=None)
+ """The model label (e.g. for plotting)."""
parameters: dict[str, float | int | Literal[ESTIMATE]]
"""PEtab problem parameters overrides for this model.
diff --git a/petab_select/plot.py b/petab_select/plot.py
index 1f2f539e..04b6e695 100644
--- a/petab_select/plot.py
+++ b/petab_select/plot.py
@@ -1,5 +1,10 @@
-"""Visualization routines for model selection."""
+"""Visualization routines for model selection.
+Plotting methods generally take a :class:`PlotData` object as input, which
+can be used to specify information used by multiple plotting methods.
+"""
+
+import copy
import warnings
from typing import Any
@@ -11,18 +16,18 @@
import networkx as nx
import numpy as np
import upsetplot
-from more_itertools import one
-from .analyze import (
- get_best_by_iteration,
- group_by_iteration,
-)
+from . import analyze
from .constants import Criterion
-from .model import VIRTUAL_INITIAL_MODEL, Model
+from .model import VIRTUAL_INITIAL_MODEL, ModelHash
from .models import Models
-RELATIVE_LABEL_FONTSIZE = -2
-NORMAL_NODE_COLOR = "darkgrey"
+LABEL_FONTSIZE = 16
+"""The font size of axis labels."""
+TICK_LABEL_FONTSIZE = LABEL_FONTSIZE - 4
+"""The font size of axis tick labels."""
+DEFAULT_NODE_COLOR = "darkgrey"
+"""The default color of nodes in graph plots."""
__all__ = [
@@ -32,32 +37,201 @@
"line_best_by_iteration",
"scatter_criterion_vs_n_estimated",
"upset",
+ "PlotData",
+ "LABEL_FONTSIZE",
+ "TICK_LABEL_FONTSIZE",
+ "DEFAULT_NODE_COLOR",
]
-def upset(
- models: Models, criterion: Criterion
-) -> dict[str, matplotlib.axes.Axes | None]:
- """Plot an UpSet plot of estimated parameters and criterion.
+class _IdentiDict(dict):
+ """Dummy dictionary that returns keys as values."""
- Args:
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __getitem__(self, key):
+ return key
+
+
+class PlotData:
+ """Manage data used in plots.
+
+ Attributes:
models:
The models.
criterion:
The criterion.
+ labels:
+ Model labels. Defaults to :attr:`petab_select.Model.model_label`
+ or :attr:`petab_select.Model.model_id` or
+ :attr:`petab_select.Model.hash`.
+ Keys are model hashes, values are the labels.
+ relative_criterion:
+ Whether to display relative criterion values.
+ parameter_labels:
+ Labels for parameters. Keys are parameter IDs, values are labels.
+ To use the ``parameterName`` column of your PEtab parameters table,
+ provide ``petab_problem.parameter_df["parameterName"].to_dict()``.
+ colors:
+ Colors for each model. Keys are model hashes, values are
+ matplotlib color specifiers (
+ https://matplotlib.org/stable/users/explain/colors/colors.html ).
+ """
+
+ def __init__(
+ self,
+ models: Models,
+ criterion: Criterion | None = None,
+ relative_criterion: bool = False,
+ labels: dict[ModelHash, str] | None = None,
+ parameter_labels: dict[str, str] | None = None,
+ colors: dict[str, str] = None,
+ ):
+ self.models = models
+ self.criterion = criterion
+ self.relative_criterion = relative_criterion
+ self.parameter_labels = parameter_labels or _IdentiDict()
+ self.colors = colors or {}
+
+ labels = labels or {}
+ self.labels = self.get_default_labels() | labels
+ self.labels0 = copy.deepcopy(self.labels)
+
+ def get_default_labels(self) -> dict[ModelHash, str]:
+ """Get default model labels."""
+ return (
+ {
+ model.hash: model.model_label or model.model_id or model.hash
+ for model in self.models
+ }
+ | {
+ model.predecessor_model_hash: model.predecessor_model_hash
+ for model in self.models
+ }
+ | {VIRTUAL_INITIAL_MODEL.hash: "Virtual\nInitial\nModel"}
+ )
+
+ def augment_labels(
+ self,
+ criterion: bool = False,
+ added_parameters: bool = False,
+ removed_parameters: bool = False,
+ ) -> None:
+ """Add information to the plotted model labels.
+
+ N.B.: this resets any previous augmentation.
+
+ Args:
+ criterion:
+ Whether to include the criterion values.
+ added_parameters:
+ Whether to include the added parameters, compared to the predecessor model.
+ removed_parameters:
+ Whether to include the removed parameters, compared to the predecessor model.
+ """
+ if criterion:
+ if not self.criterion:
+ raise ValueError(
+ "Please construct the `PlotData` object with a specified "
+ "`criterion` first."
+ )
+ criterion_values = self.models.get_criterion(
+ criterion=self.criterion,
+ relative=self.relative_criterion,
+ as_dict=True,
+ )
+ if added_parameters or removed_parameters:
+ parameter_changes = analyze.get_parameter_changes(
+ models=self.models,
+ as_dict=True,
+ )
+
+ self.labels = copy.deepcopy(self.labels0)
+ for model_hash, model_label in self.labels.items():
+ if model_hash == VIRTUAL_INITIAL_MODEL.hash:
+ continue
+
+ criterion_label = None
+ added_parameters_label = None
+ removed_parameters_label = None
+
+ if criterion and model_hash in self.models:
+ criterion_label = f"{criterion_values[model_hash]:.2f}"
+
+ if added_parameters:
+ added_parameters_label = (
+ "+ {"
+ + ",".join(
+ self.parameter_labels[parameter_id]
+ for parameter_id in sorted(
+ parameter_changes[model_hash][0]
+ )
+ )
+ + "}"
+ )
+
+ if removed_parameters:
+ removed_parameters_label = (
+ "- {"
+ + ",".join(
+ self.parameter_labels[parameter_id]
+ for parameter_id in sorted(
+ parameter_changes[model_hash][1]
+ )
+ )
+ + "}"
+ )
+
+ self.labels[model_hash] = "\n".join(
+ str(label_part)
+ for label_part in [
+ self.labels[model_hash],
+ criterion_label,
+ added_parameters_label,
+ removed_parameters_label,
+ ]
+ if label_part is not None
+ )
+
+
+def _set_axis_fontsizes(ax: plt.Axes) -> None:
+ """Set the axes (tick) label fontsizes to the global variables.
+
+ Customize via :const:`petab_select.plot.LABEL_FONTSIZE` and
+ :const:`petab_select.plot.TICK_LABEL_FONTSIZE`.
+ """
+ ax.xaxis.label.set_size(LABEL_FONTSIZE)
+ ax.yaxis.label.set_size(LABEL_FONTSIZE)
+ ax.xaxis.set_tick_params(which="major", labelsize=TICK_LABEL_FONTSIZE)
+ ax.yaxis.set_tick_params(which="major", labelsize=TICK_LABEL_FONTSIZE)
+
+
+def upset(plot_data: PlotData) -> dict[str, matplotlib.axes.Axes | None]:
+ """Plot an UpSet plot of estimated parameters and criterion.
+
+ Args:
+ plot_data:
+ The plot data.
Returns:
- The plot axes (see documentation from the `upsetplot `__ package).
+ The plot axes (see documentation from the
+ `upsetplot `__ package).
"""
# Get delta criterion values
- values = np.array(models.get_criterion(criterion=criterion, relative=True))
+ values = np.array(
+ plot_data.models.get_criterion(
+ criterion=plot_data.criterion,
+ relative=plot_data.relative_criterion,
+ )
+ )
# Sort by criterion value
index = np.argsort(values)
values = values[index]
labels = [
model.get_estimated_parameter_ids()
- for model in np.array(models)[index]
+ for model in np.array(plot_data.models)[index]
]
with warnings.catch_warnings():
@@ -70,48 +244,32 @@ def upset(
axes = upsetplot.plot(
series, totals_plot_elements=0, with_lines=False, sort_by="input"
)
- axes["intersections"].set_ylabel(r"$\Delta$" + criterion)
+ axes["intersections"].set_ylabel(r"$\Delta$" + plot_data.criterion)
+ _set_axis_fontsizes(ax=axes["intersections"])
+ _set_axis_fontsizes(ax=axes["matrix"])
return axes
def line_best_by_iteration(
- models: Models,
- criterion: Criterion,
- relative: bool = True,
- fz: int = 14,
- labels: dict[str, str] = None,
+ plot_data: PlotData,
ax: plt.Axes = None,
) -> plt.Axes:
"""Plot the improvement in criterion across iterations.
Args:
- models:
- A list of all models. The selected models will be inferred from the
- best model and its chain of predecessor model hashes.
- criterion:
- The criterion by which models are selected.
- relative:
- If ``True``, criterion values are plotted relative to the lowest
- criterion value.
- fz:
- fontsize
- labels:
- A dictionary of model labels, where keys are model hashes, and
- values are model labels, for plotting. If a model label is not
- provided, it will be generated from its model ID.
+ plot_data:
+ The plot data.
ax:
The axis to use for plotting.
Returns:
The plot axes.
"""
- best_by_iteration = get_best_by_iteration(
- models=models, criterion=criterion
+ best_by_iteration = analyze.get_best_by_iteration(
+ models=plot_data.models,
+ criterion=plot_data.criterion,
)
- if labels is None:
- labels = {}
-
# FIGURE
if ax is None:
_, ax = plt.subplots(figsize=(5, 4))
@@ -122,19 +280,20 @@ def line_best_by_iteration(
[best_by_iteration[iteration] for iteration in iterations]
)
iteration_labels = [
- str(iteration) + f"\n({labels.get(model.hash, model.model_id)})"
+ str(iteration)
+ + f"\n({plot_data.labels.get(model.hash, model.model_id)})"
for iteration, model in zip(iterations, best_models, strict=True)
]
criterion_values = best_models.get_criterion(
- criterion=criterion, relative=relative
+ criterion=plot_data.criterion, relative=plot_data.relative_criterion
)
ax.plot(
iteration_labels,
criterion_values,
linewidth=linewidth,
- color=NORMAL_NODE_COLOR,
+ color=DEFAULT_NODE_COLOR,
marker="x",
markersize=10,
markeredgewidth=2,
@@ -144,16 +303,13 @@ def line_best_by_iteration(
ax.get_xticks()
ax.set_xticks(list(range(len(criterion_values))))
- ax.set_xlabel("Iteration and model", fontsize=fz)
- ax.set_ylabel((r"$\Delta$" if relative else "") + criterion, fontsize=fz)
- # could change to compared_model_ids, if all models are plotted
- ax.set_xticklabels(
- ax.get_xticklabels(),
- fontsize=fz + RELATIVE_LABEL_FONTSIZE,
- )
- ax.yaxis.set_tick_params(
- which="major", labelsize=fz + RELATIVE_LABEL_FONTSIZE
+ ax.set_xlabel("Iteration and model")
+ ax.set_ylabel(
+ (r"$\Delta$" if plot_data.relative_criterion else "")
+ + plot_data.criterion
)
+ # could change to compared_model_ids, if all models are plotted
+ _set_axis_fontsizes(ax=ax)
ytl = ax.get_yticks()
ax.set_ylim([min(ytl), max(ytl)])
# removing top and right borders
@@ -163,33 +319,18 @@ def line_best_by_iteration(
def graph_history(
- models: Models,
- criterion: Criterion = None,
- labels: dict[str, str] = None,
- colors: dict[str, str] = None,
+ plot_data: PlotData,
draw_networkx_kwargs: dict[str, Any] = None,
- relative: bool = True,
spring_layout_kwargs: dict[str, Any] = None,
ax: plt.Axes = None,
) -> plt.Axes:
"""Plot all calibrated models in the model space, as a directed graph.
Args:
- models:
- A list of models.
- criterion:
- The criterion.
- labels:
- A dictionary of model labels, where keys are model hashes, and
- values are model labels, for plotting. If a model label is not
- provided, it will be generated from its model ID.
- colors:
- Colors for each model, using their labels.
+ plot_data:
+ The plot data.
draw_networkx_kwargs:
Forwarded to ``networkx.draw_networkx``.
- relative:
- If ``True``, criterion values are offset by the minimum criterion
- value.
spring_layout_kwargs:
Forwarded to ``networkx.spring_layout``.
ax:
@@ -203,49 +344,14 @@ def graph_history(
if spring_layout_kwargs is None:
spring_layout_kwargs = default_spring_layout_kwargs
- criterion_values = models.get_criterion(
- criterion=criterion, relative=relative, as_dict=True
+ criterion_values = plot_data.models.get_criterion(
+ criterion=plot_data.criterion,
+ relative=plot_data.relative_criterion,
+ as_dict=True,
)
- if labels is None:
- labels = {
- model.hash: model.model_id
- + (
- f"\n{criterion_values[model.hash]:.2f}"
- if criterion is not None
- else ""
- )
- for model in models
- }
- labels = labels.copy()
- labels[VIRTUAL_INITIAL_MODEL.hash] = "Virtual\nInitial\nModel"
-
- G = nx.DiGraph()
- edges = []
- for model in models:
- predecessor_model_hash = model.predecessor_model_hash
- if predecessor_model_hash is not None:
- from_ = labels.get(predecessor_model_hash, predecessor_model_hash)
- # may only not be the case for
- # COMPARED_MODEL_ID == INITIAL_VIRTUAL_MODEL
- if predecessor_model_hash in models:
- predecessor_model = models[predecessor_model_hash]
- from_ = labels.get(
- predecessor_model.hash,
- predecessor_model.model_id,
- )
- else:
- raise NotImplementedError(
- "Plots for models with `None` as their predecessor model are "
- "not yet implemented."
- )
- from_ = "None"
- to = labels.get(model.hash, model.model_id)
- edges.append((from_, to))
-
- G.add_edges_from(edges)
default_draw_networkx_kwargs = {
- "node_color": NORMAL_NODE_COLOR,
+ "node_color": DEFAULT_NODE_COLOR,
"arrowstyle": "-|>",
"node_shape": "s",
"node_size": 2500,
@@ -253,18 +359,22 @@ def graph_history(
}
if draw_networkx_kwargs is None:
draw_networkx_kwargs = default_draw_networkx_kwargs
- if colors is not None:
- if label_diff := set(colors).difference(list(G)):
- raise ValueError(
- "Colors were provided for the following model labels, but "
- f"these are not in the graph: {label_diff}"
- )
+ G = analyze.get_graph(models=plot_data.models)
+ # Set colors
+ if label_diff := set(plot_data.colors).difference(G.nodes):
+ raise ValueError(
+ "Colors were provided for the following model labels, but "
+ f"these are not in the graph: `{label_diff}`."
+ )
- node_colors = [
- colors.get(model_label, default_draw_networkx_kwargs["node_color"])
- for model_label in list(G)
- ]
- draw_networkx_kwargs.update({"node_color": node_colors})
+ node_colors = [
+ plot_data.colors.get(
+ model_hash, default_draw_networkx_kwargs["node_color"]
+ )
+ for model_hash in G.nodes
+ ]
+ draw_networkx_kwargs.update({"node_color": node_colors})
+ nx.relabel_nodes(G, mapping=plot_data.labels, copy=False)
if ax is None:
_, ax = plt.subplots(figsize=(12, 12))
@@ -276,35 +386,19 @@ def graph_history(
def bar_criterion_vs_models(
- models: list[Model],
- criterion: Criterion = None,
- labels: dict[str, str] = None,
- relative: bool = True,
+ plot_data: PlotData,
ax: plt.Axes = None,
bar_kwargs: dict[str, Any] = None,
- colors: dict[str, str] = None,
) -> plt.Axes:
"""Plot all calibrated models and their criterion value.
Args:
- models:
- A list of models.
- criterion:
- The criterion.
- labels:
- A dictionary of model labels, where keys are model hashes, and
- values are model labels, for plotting. If a model label is not
- provided, it will be generated from its model ID.
- relative:
- If `True`, criterion values are offset by the minimum criterion
- value.
+ plot_data:
+ The plot data.
ax:
The axis to use for plotting.
bar_kwargs:
Passed to the matplotlib `ax.bar` call.
- colors:
- Custom colors. Keys are model hashes or ``labels`` (if provided),
- values are matplotlib colors.
Returns:
The plot axes.
@@ -312,69 +406,52 @@ def bar_criterion_vs_models(
if bar_kwargs is None:
bar_kwargs = {}
- if labels is None:
- labels = {model.hash: model.model_id for model in models}
-
if ax is None:
_, ax = plt.subplots()
- bar_model_labels = [
- labels.get(model.hash, model.model_id) for model in models
- ]
- criterion_values = models.get_criterion(
- criterion=criterion, relative=relative
+ criterion_values = plot_data.models.get_criterion(
+ criterion=plot_data.criterion,
+ relative=plot_data.relative_criterion,
)
- if colors is not None:
- if label_diff := set(colors).difference(bar_model_labels):
- raise ValueError(
- "Colors were provided for the following model labels, but "
- f"these are not in the graph: {label_diff}"
- )
+ if label_diff := set(plot_data.colors).difference(plot_data.labels):
+ raise ValueError(
+ "Colors were provided for the following model labels, but "
+ f"these are not in the graph: {label_diff}"
+ )
- bar_kwargs["color"] = [
- colors.get(model_label, NORMAL_NODE_COLOR)
- for model_label in criterion_values
- ]
+ bar_kwargs["color"] = [
+ plot_data.colors.get(model_label, DEFAULT_NODE_COLOR)
+ for model_label in criterion_values
+ ]
- ax.bar(bar_model_labels, criterion_values, **bar_kwargs)
+ labels = [
+ plot_data.labels[model_hash] for model_hash in plot_data.models.hashes
+ ]
+ ax.bar(labels, criterion_values, **bar_kwargs)
ax.set_xlabel("Model")
ax.set_ylabel(
- (r"$\Delta$" if relative else "") + criterion,
+ (r"$\Delta$" if plot_data.relative_criterion else "")
+ + plot_data.criterion
)
+ _set_axis_fontsizes(ax=ax)
return ax
def scatter_criterion_vs_n_estimated(
- models: list[Model],
- criterion: Criterion = None,
- relative: bool = True,
+ plot_data: PlotData,
ax: plt.Axes = None,
- colors: dict[str, str] = None,
- labels: dict[str, str] = None,
scatter_kwargs: dict[str, str] = None,
max_jitter: float = 0.2,
) -> plt.Axes:
"""Plot criterion values against number of estimated parameters.
Args:
- models:
- A list of models.
- criterion:
- The criterion.
- relative:
- If `True`, criterion values are offset by the minimum criterion
- value.
+ plot_data:
+ The plot data.
ax:
The axis to use for plotting.
- colors:
- Custom colors. Keys are model hashes or ``labels`` (if provided),
- values are matplotlib colors.
- labels:
- A dictionary of model labels, where keys are model hashes, and
- values are model labels, for plotting. If a model label is not
- provided, it will be generated from its model ID.
scatter_kwargs:
Forwarded to ``matplotlib.axes.Axes.scatter``.
max_jitter:
@@ -385,31 +462,25 @@ def scatter_criterion_vs_n_estimated(
Returns:
The plot axes.
"""
- labels = {
- model.hash: labels.get(model.model_id, model.model_id)
- for model in models
- }
-
if scatter_kwargs is None:
scatter_kwargs = {}
- if colors is not None:
- if label_diff := set(colors).difference(labels.values()):
- raise ValueError(
- "Colors were provided for the following model labels, but "
- f"these are not in the graph: {label_diff}"
- )
- scatter_kwargs["c"] = [
- colors.get(model_label, NORMAL_NODE_COLOR)
- for model_label in labels.values()
- ]
+ if hash_diff := set(plot_data.colors).difference(plot_data.models.hashes):
+ raise ValueError(
+ "Colors were provided for the following model hashes, but "
+ f"these are not in the graph: {hash_diff}"
+ )
+ scatter_kwargs["c"] = [
+ plot_data.colors.get(model_hash, DEFAULT_NODE_COLOR)
+ for model_hash in plot_data.models.hashes
+ ]
n_estimated = []
- for model in models:
+ for model in plot_data.models:
n_estimated.append(len(model.get_estimated_parameter_ids()))
- criterion_values = models.get_criterion(
- criterion=criterion, relative=relative
+ criterion_values = plot_data.models.get_criterion(
+ criterion=plot_data.criterion, relative=plot_data.relative_criterion
)
if max_jitter:
@@ -432,49 +503,31 @@ def scatter_criterion_vs_n_estimated(
ax.set_xlabel("Number of estimated parameters")
ax.set_ylabel(
- (r"$\Delta$" if relative else "") + criterion,
+ (r"$\Delta$" if plot_data.relative_criterion else "")
+ + plot_data.criterion,
)
+ _set_axis_fontsizes(ax=ax)
return ax
def graph_iteration_layers(
- models: list[Model],
- criterion: Criterion,
- labels: dict[str, str] = None,
- relative: bool = True,
+ plot_data: PlotData,
ax: plt.Axes = None,
draw_networkx_kwargs: dict[str, Any] | None = None,
# colors: Dict[str, str] = None,
- parameter_labels: dict[str, str] = None,
- augment_labels: bool = True,
colorbar_mappable: matplotlib.cm.ScalarMappable = None,
# use_tex: bool = True,
) -> plt.Axes:
"""Graph the models of each iteration of model selection.
Args:
- models:
- A list of models.
- criterion:
- The criterion.
- labels:
- A dictionary of model labels, where keys are model hashes, and
- values are model labels, for plotting. If a model label is not
- provided, it will be generated from its model ID.
- relative:
- If ``True``, criterion values are offset by the minimum criterion
- value.
+ plot_data:
+ The plot data.
ax:
The axis to use for plotting.
draw_networkx_kwargs:
Passed to the ``networkx.draw_networkx`` call.
- parameter_labels:
- A dictionary of parameter labels, where keys are parameter IDs, and
- values are parameter labels, for plotting. Defaults to parameter IDs.
- augment_labels:
- If ``True``, provided labels will be augmented with the relative
- change in parameters.
colorbar_mappable:
Customize the colors.
See documentation for the `mappable` argument of
@@ -487,7 +540,6 @@ def graph_iteration_layers(
_, ax = plt.subplots(figsize=(20, 10))
default_draw_networkx_kwargs = {
- #'node_color': NORMAL_NODE_COLOR,
"arrowstyle": "-|>",
"node_shape": "s",
"node_size": 250,
@@ -496,110 +548,24 @@ def graph_iteration_layers(
if draw_networkx_kwargs is None:
draw_networkx_kwargs = default_draw_networkx_kwargs
- ancestry = {model.hash: model.predecessor_model_hash for model in models}
- ancestry_as_set = {k: {v} for k, v in ancestry.items()}
+ model_criterion_values = plot_data.models.get_criterion(
+ criterion=plot_data.criterion,
+ relative=plot_data.relative_criterion,
+ as_dict=True,
+ )
+ G = analyze.get_graph(models=plot_data.models)
+
+ # The ordering of models into iterations
ordering = [
[model.hash for model in iteration_models]
- for iteration_models in group_by_iteration(models).values()
+ for iteration_models in analyze.group_by_iteration(
+ plot_data.models
+ ).values()
]
- if VIRTUAL_INITIAL_MODEL.hash in ancestry.values():
+ if VIRTUAL_INITIAL_MODEL.hash in G.nodes:
ordering.insert(0, [VIRTUAL_INITIAL_MODEL.hash])
- model_estimated_parameters = {
- model.hash: set(model.estimated_parameters) for model in models
- }
- model_criterion_values = models.get_criterion(
- criterion=criterion, relative=relative, as_dict=True
- )
-
- model_parameter_diffs = {
- model.hash: (
- (set(), set())
- if model.predecessor_model_hash not in model_estimated_parameters
- else (
- model_estimated_parameters[model.hash].difference(
- model_estimated_parameters[model.predecessor_model_hash]
- ),
- model_estimated_parameters[
- model.predecessor_model_hash
- ].difference(model_estimated_parameters[model.hash]),
- )
- )
- for model in models
- }
-
- labels = labels or {}
- labels[VIRTUAL_INITIAL_MODEL.hash] = labels.get(
- VIRTUAL_INITIAL_MODEL.hash, "Virtual\nInitial\nModel"
- )
- labels = (
- labels
- | {
- model.hash: model.model_id
- for model in models
- if model.hash not in labels
- }
- | {
- model.predecessor_model_hash: model.predecessor_model_hash
- for model in models
- if model.predecessor_model_hash not in labels
- }
- )
- if augment_labels:
-
- class Identidict(dict):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def __getitem__(self, key):
- return key
-
- if parameter_labels is None:
- parameter_labels = Identidict()
-
- model_added_parameters = {
- model_hash: ",".join(
- [
- parameter_labels[parameter_id]
- for parameter_id in sorted(
- model_parameter_diffs[model_hash][0]
- )
- ]
- )
- for model_hash in model_estimated_parameters
- }
- model_removed_parameters = {
- model_hash: ",".join(
- [
- parameter_labels[parameter_id]
- for parameter_id in sorted(
- model_parameter_diffs[model_hash][1]
- )
- ]
- )
- for model_hash in model_estimated_parameters
- }
-
- labels = {
- model_hash: (
- label0
- if model_hash == VIRTUAL_INITIAL_MODEL.hash
- else "\n".join(
- [
- label0,
- "+ {"
- + model_added_parameters.get(model_hash, "")
- + "}",
- "- {"
- + model_removed_parameters.get(model_hash, "")
- + "}",
- ]
- )
- )
- for model_hash, label0 in labels.items()
- }
-
if colorbar_mappable is None:
norm = matplotlib.colors.Normalize(
vmin=min(model_criterion_values.values()),
@@ -610,14 +576,9 @@ def __getitem__(self, key):
)
colorbar_mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = ax.get_figure().colorbar(colorbar_mappable, ax=ax)
- cbar.ax.set_title(r"$\Delta$" + criterion)
+ cbar.ax.set_title(r"$\Delta$" + plot_data.criterion)
- G = nx.DiGraph(
- [
- (predecessor, successor)
- for successor, predecessor in ancestry.items()
- ]
- )
+ # Model node positions
# FIXME change positions so that top edge of topmost node, and bottom edge
# of bottommost nodes, are at 1 and 0, respectively
X = [
@@ -633,24 +594,23 @@ def __getitem__(self, key):
]
pos = {
- labels[model_hash]: (X[i], Y[i][j])
+ plot_data.labels[model_hash]: (X[i], Y[i][j])
for i, layer in enumerate(ordering)
for j, model_hash in enumerate(layer)
}
- nx.relabel_nodes(G, mapping=labels, copy=False)
- G_hashes = [
- one([k for k, v in labels.items() if v == label]) for label in G.nodes
- ]
node_colors = [
(
colorbar_mappable.to_rgba(model_criterion_values[model_hash])
if model_hash in model_criterion_values
- else NORMAL_NODE_COLOR
+ else DEFAULT_NODE_COLOR
)
- for model_hash in G_hashes
+ for model_hash in G.nodes
]
+ # Apply custom labels
+ nx.relabel_nodes(G, mapping=plot_data.labels, copy=False)
+
nx.draw_networkx(
G, pos, ax=ax, node_color=node_colors, **draw_networkx_kwargs
)
@@ -665,76 +625,6 @@ def __getitem__(self, key):
ha="center",
)
- ## Get selected parameter IDs
- ## TODO move this logic elsewhere
- # selected_hashes = set(ancestry.values())
- # selected_models = {}
- # for model in models:
- # if model.hash in selected_hashes:
- # selected_models[model.hash] = model
-
- # selected_parameters = {
- # model_hash: sorted(model.estimated_parameters)
- # for model_hash, model in selected_models.items()
- # }
-
- # selected_order = [
- # [model_hash for model_hash in layer if model_hash in selected_models]
- # for layer in ordering
- # ]
- # selected_order = [
- # None if not model_hash else one(model_hash)
- # for model_hash in selected_order
- # ]
-
- # selected_parameter_ids = []
- # estimated0 = None
- # model_hash = None
- # for model_hash in selected_order:
- # if model_hash is None:
- # selected_parameter_ids.append('')
- # continue
- # if estimated0 is not None:
- # new_parameter_ids = list(
- # set(selected_parameters[model_hash]).symmetric_difference(
- # estimated0
- # )
- # )
- # new_parameter_names = []
- # for new_parameter_id in new_parameter_ids:
- # # Default to parameter ID, use parameter name if available
- # new_parameter_name = new_parameter_id
- # if (
- # PARAMETER_NAME
- # in selected_models[
- # model_hash
- # ].petab_problem.parameter_df.columns
- # ):
- # petab_parameter_name = selected_models[
- # model_hash
- # ].petab_problem.parameter_df.loc[
- # new_parameter_id, PARAMETER_NAME
- # ]
- # if not pd.isna(petab_parameter_name):
- # new_parameter_name = petab_parameter_name
- # new_parameter_names.append(new_parameter_name)
- # new_parameter_names = [
- # new_parameter_name.replace('\\\\rightarrow ', '->')
- # for new_parameter_name in new_parameter_names
- # ]
- # selected_parameter_ids.append(sorted(new_parameter_names))
- # else:
- # selected_parameter_ids.append([''])
- # estimated0 = selected_parameters[model_hash]
-
- ## Add labels for selected parameters
- # for x, label in zip(X, selected_parameter_ids):
- # ax.annotate(
- # "\n".join(label),
- # xy=(x, 1.15),
- # fontsize=draw_networkx_kwargs.get('font_size', 20),
- # )
-
# Set margins for the axes so that nodes aren't clipped
ax.margins(0.15)
ax.axis("off")
diff --git a/petab_select/problem.py b/petab_select/problem.py
index e30511c7..a64e62a1 100644
--- a/petab_select/problem.py
+++ b/petab_select/problem.py
@@ -265,7 +265,7 @@ def get_best(
The best model.
"""
warnings.warn(
- "Use ``petab_select.analyze.get_best`` instead.",
+ "Use ``petab_select.ui.get_best`` instead.",
DeprecationWarning,
stacklevel=2,
)
diff --git a/pyproject.toml b/pyproject.toml
index 934171de..4cac5fa5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,13 +24,12 @@ dependencies = [
"click>=8.1.7",
"dill>=0.3.9",
"mkstd>=0.0.8",
+ "networkx>=3.2",
]
[project.optional-dependencies]
plot = [
"matplotlib>=2.2.3",
"upsetplot",
- "toposort",
- "networkx",
]
test = [
"pytest >= 5.4.3",