From 09a916a3bf4c7e3494507af1d13673e2b78c8ff3 Mon Sep 17 00:00:00 2001 From: Arya Massarat <23412689+aryarm@users.noreply.github.com> Date: Thu, 5 Dec 2024 10:19:56 -0800 Subject: [PATCH] feat: allow for specifying the order of population labels in `Breakpoints.encode()` (#262) --- haptools/data/breakpoints.py | 25 ++++++++++++++++--------- tests/test_data.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/haptools/data/breakpoints.py b/haptools/data/breakpoints.py index 08b118d9..6610ca46 100644 --- a/haptools/data/breakpoints.py +++ b/haptools/data/breakpoints.py @@ -160,26 +160,31 @@ def __iter__(self, samples: set[str] = None) -> Iterable[str, SampleBlocks]: yield samp, [np.array(b, dtype=HapBlock) for b in blocks] bps.close() - def encode(self) -> dict[int, str]: + def encode(self, labels: tuple[str] = None): """ Replace each ancestral label in :py:attr:`~.Breakpoints.data` with an equivalent integer. Store a dictionary mapping these integers back to their - respective labels. + respective labels in :py:attr:`~.Breakpoints.labels`. This method modifies :py:attr:`~.Breakpoints.data` in place. - Returns - ------- - dict[int, str] - A dictionary mapping each integer back to its ancestral label + Parameters + ---------- + labels: tuple[str], optional + A list of population labels. The order of the labels in this list will be + kept in the respective labels. """ if not (self.labels is None): raise ValueError("The data has already been encoded.") # save the order of the fields for later reordering names = [f[0] for f in HapBlock] # initialize labels dict and label counter - labels = {} - pop_count = 0 + if labels is None: + labels = {} + else: + labels = {pop: i for i, pop in enumerate(labels)} + pop_count = len(labels) + seen = set() for sample, blocks in self.data.items(): for strand_num in range(len(blocks)): # initialize and fill the array of integers @@ -189,10 +194,11 @@ def encode(self) -> dict[int, str]: labels[pop] = pop_count pop_count += 1 ints[i] = labels[pop] + seen.add(pop) # replace the "pop" labels arr = rcf.drop_fields(blocks[strand_num], ["pop"]) blocks[strand_num] = rcf.merge_arrays((arr, ints), flatten=True)[names] - self.labels = labels + self.labels = {k: v for k, v in labels.items() if k in seen} def recode(self): """ @@ -332,6 +338,7 @@ def write(self): -------- To write to a file, you must first initialize a Breakpoints object and then fill out the names, data, and samples properties: + >>> from haptools.data import Breakpoints, HapBlock >>> breakpoints = Breakpoints('simple.bp') >>> breakpoints.data = { diff --git a/tests/test_data.py b/tests/test_data.py index d5d23fd8..2e0057c2 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -2075,6 +2075,40 @@ def test_encode(self): for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]): assert expected.labels[exp] == obs + def test_encode_reorder(self): + expected = self._get_expected_breakpoints() + expected.labels = {"CEU": 0, "YRI": 1} + + observed = self._get_expected_breakpoints() + observed.encode(labels=("CEU", "YRI", "AMR")) + + assert observed.labels == expected.labels + assert len(expected.data) == len(observed.data) + for sample in expected.data: + for strand in range(len(expected.data[sample])): + exp_strand = expected.data[sample][strand] + obs_strand = observed.data[sample][strand] + assert len(exp_strand) == len(observed.data[sample][strand]) + for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]): + assert expected.labels[exp] == obs + + # now try again with AMR in the middle + # In that case, it should keep the ordering when deciding the integers + # but the final labels should include the AMR key + expected.labels = {"CEU": 0, "YRI": 2} + observed = self._get_expected_breakpoints() + observed.encode(labels=("CEU", "AMR", "YRI")) + + assert observed.labels == expected.labels + assert len(expected.data) == len(observed.data) + for sample in expected.data: + for strand in range(len(expected.data[sample])): + exp_strand = expected.data[sample][strand] + obs_strand = observed.data[sample][strand] + assert len(exp_strand) == len(observed.data[sample][strand]) + for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]): + assert expected.labels[exp] == obs + def test_recode(self): expected = self._get_expected_breakpoints()