Skip to content

Commit

Permalink
add params to CurveFits.combineCurveFits (#41)
Browse files Browse the repository at this point in the history
* draft of `combineCurveFits` w new params

* test new params to `CurveFits.combineCurveFits`
  • Loading branch information
jbloom authored Dec 14, 2023
1 parent 51d9677 commit 6f544b1
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 85 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.

0.10.0
------

Added
+++++
- Added the following parameters to ``CurveFits.combineCurveFits``: ``sera``, ``viruses``, ``serum_virus_replicates_to_drop``.

0.9.0
-----

Expand Down
2 changes: 1 addition & 1 deletion neutcurve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

__author__ = "Jesse Bloom"
__email__ = "[email protected]"
__version__ = "0.9.0"
__version__ = "0.10.0"
__url__ = "https://github.com/jbloomlab/neutcurve"

from neutcurve.curvefits import CurveFits # noqa: F401
Expand Down
123 changes: 79 additions & 44 deletions neutcurve/curvefits.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,27 @@ class CurveFits:
_WILDTYPE_NAMES = ("WT", "wt", "wildtype", "Wildtype", "wild type", "Wild type")

@staticmethod
def combineCurveFits(curvefits_list):
def combineCurveFits(
curvefits_list,
*,
sera=None,
viruses=None,
serum_virus_replicates_to_drop=None,
):
"""
Args:
`curvefits_list` (list)
List of :class:`CurveFits` objects that are identical other than the
data they contain and have unique virus/serum/replicate combinations.
They can differ in `fixtop` and `fixbottom`, but then those
will be set to `None` in the returned object.
`sera` (None or list)
Only keep fits for sera in this list, or keep all sera if `None`.
`viruses` (None or list)
Only keep fits for viruses in this list, or keep all viruses if `None`.
`serum_virus_replicates_to_drop` (None or list)
If a list, should specify `(serum, virus, replicate)` tuples; the fits for
those particular serum/virus/replicate combinations are dropped.
Returns:
combined_fits (:class:`CurveFits`)
Expand Down Expand Up @@ -147,6 +160,31 @@ def combineCurveFits(curvefits_list):
combined_fits.df = combined_fits.df[
combined_fits.df[combined_fits.replicate_col] != "average"
]
for col, keeplist in [
(combined_fits.serum_col, sera),
(combined_fits.virus_col, viruses),
]:
if keeplist is not None:
combined_fits.df = combined_fits.df[
combined_fits.df[col].isin(keeplist)
]
if serum_virus_replicates_to_drop:
assert "tup" not in set(combined_fits.df.columns)
combined_fits.df = (
combined_fits.df.assign(
tup=lambda x: list(
x[
[
combined_fits.serum_col,
combined_fits.virus_col,
combined_fits.replicate_col,
]
].itertuples(index=False, name=None)
),
)
.query("tup not in @serum_virus_replicates_to_drop")
.drop(columns="tup")
)
combined_fits.df = combined_fits._get_avg_and_stderr_df(combined_fits.df)
if len(combined_fits.df) != len(
combined_fits.df.groupby(
Expand All @@ -161,56 +199,47 @@ def combineCurveFits(curvefits_list):
raise ValueError("duplicated sera/virus/replicate in `curvefits_list`")

# combine sera
combined_fits.sera = list(
dict.fromkeys([serum for f in curvefits_list for serum in f.sera])
)
assert set(combined_fits.sera) == set(combined_fits.df[combined_fits.serum_col])
combined_fits.sera = combined_fits.df[combined_fits.serum_col].unique().tolist()

# combine allviruses
combined_fits.allviruses = list(
dict.fromkeys([virus for f in curvefits_list for virus in f.allviruses])
)
assert set(combined_fits.allviruses) == set(
combined_fits.df[combined_fits.virus_col]
combined_fits.allviruses = (
combined_fits.df[combined_fits.virus_col].unique().tolist()
)

# combine viruses and replicates
combined_fits.viruses = {}
combined_fits.replicates = {}
for serum in combined_fits.sera:
combined_fits.viruses[serum] = []
for virus in combined_fits.allviruses:
for f in curvefits_list:
if (
(serum in f.viruses)
and (virus in f.viruses[serum])
and (virus not in combined_fits.viruses[serum])
):
combined_fits.viruses[serum].append(virus)
if (serum, virus) in f.replicates:
if (serum, virus) not in combined_fits.replicates:
combined_fits.replicates[(serum, virus)] = f.replicates[
(serum, virus)
]
else:
combined_fits.replicates[(serum, virus)] += f.replicates[
(serum, virus)
]
for key, val in combined_fits.replicates.items():
val = dict.fromkeys(val)
if "average" in val:
del val["average"]
if len(val) != len(set(val)):
raise ValueError(f"duplicate replicate for {key}")
combined_fits.replicates[key] = list(val)
combined_fits.replicates[key].append("average")
assert combined_fits.serum_col != "viruses"
combined_fits.viruses = (
combined_fits.df.groupby(combined_fits.serum_col, sort=False)
.aggregate(
viruses=pd.NamedAgg(
combined_fits.virus_col,
lambda s: s.unique().tolist(),
),
)["viruses"]
.to_dict()
)
assert combined_fits.serum_col != "replicate"
assert combined_fits.virus_col != "replicate"
combined_fits.replicates = (
combined_fits.df[combined_fits.df[combined_fits.replicate_col] != "average"]
.groupby([combined_fits.serum_col, combined_fits.virus_col], sort=False)
.aggregate(
replicates=pd.NamedAgg(
combined_fits.replicate_col,
lambda s: s.unique().tolist(),
),
)["replicates"]
.to_dict()
)
for serum, virus in combined_fits.replicates:
combined_fits.replicates[(serum, virus)].append("average")
serum_virus_rep_tups = [
(serum, virus, rep)
for (serum, virus), reps in combined_fits.replicates.items()
for rep in reps
]
assert len(serum_virus_rep_tups) == len(set(serum_virus_rep_tups))
assert set(serum_virus_rep_tups) == set(
combined_fits_tups = set(
combined_fits.df[
[
combined_fits.serum_col,
Expand All @@ -219,7 +248,10 @@ def combineCurveFits(curvefits_list):
]
].itertuples(index=False, name=None)
)
assert set(combined_fits.allviruses) == set(
assert (
set(serum_virus_rep_tups) == combined_fits_tups
), f"{combined_fits_tups=}\n\n{serum_virus_rep_tups=}"
assert set(combined_fits.allviruses).issubset(
v for f in curvefits_list for s in f.viruses.values() for v in s
)

Expand All @@ -228,9 +260,12 @@ def combineCurveFits(curvefits_list):
combined_fits._hillcurves = {}
for c in curvefits_list:
for (serum, virus, replicate), curve in c._hillcurves.items():
assert serum in combined_fits.sera
assert virus in combined_fits.allviruses
if replicate != "average":
if (
(serum in combined_fits.sera)
and (virus in combined_fits.allviruses)
and (replicate in combined_fits.replicates[(serum, virus)])
and (replicate != "average")
):
combined_fits._hillcurves[(serum, virus, replicate)] = curve

combined_fits._fitparams = {} # clear this cache
Expand Down
Loading

0 comments on commit 6f544b1

Please sign in to comment.