From 09a916a3bf4c7e3494507af1d13673e2b78c8ff3 Mon Sep 17 00:00:00 2001
From: Arya Massarat <23412689+aryarm@users.noreply.github.com>
Date: Thu, 5 Dec 2024 10:19:56 -0800
Subject: [PATCH] feat: allow for specifying the order of population labels in
 `Breakpoints.encode()` (#262)

---
 haptools/data/breakpoints.py | 25 ++++++++++++++++---------
 tests/test_data.py           | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/haptools/data/breakpoints.py b/haptools/data/breakpoints.py
index 08b118d9..6610ca46 100644
--- a/haptools/data/breakpoints.py
+++ b/haptools/data/breakpoints.py
@@ -160,26 +160,31 @@ def __iter__(self, samples: set[str] = None) -> Iterable[str, SampleBlocks]:
             yield samp, [np.array(b, dtype=HapBlock) for b in blocks]
         bps.close()
 
-    def encode(self) -> dict[int, str]:
+    def encode(self, labels: tuple[str] = None):
         """
         Replace each ancestral label in :py:attr:`~.Breakpoints.data` with an
         equivalent integer. Store a dictionary mapping these integers back to their
-        respective labels.
+        respective labels in :py:attr:`~.Breakpoints.labels`.
 
         This method modifies :py:attr:`~.Breakpoints.data` in place.
 
-        Returns
-        -------
-        dict[int, str]
-            A dictionary mapping each integer back to its ancestral label
+        Parameters
+        ----------
+        labels: tuple[str], optional
+            Population labels in the desired order; the i-th label is encoded as the
+            integer i. Labels absent from the data are omitted from the stored mapping.
         """
         if not (self.labels is None):
             raise ValueError("The data has already been encoded.")
         # save the order of the fields for later reordering
         names = [f[0] for f in HapBlock]
         # initialize labels dict and label counter
-        labels = {}
-        pop_count = 0
+        if labels is None:
+            labels = {}
+        else:
+            labels = {pop: i for i, pop in enumerate(labels)}
+        pop_count = len(labels)
+        seen = set()
         for sample, blocks in self.data.items():
             for strand_num in range(len(blocks)):
                 # initialize and fill the array of integers
@@ -189,10 +194,11 @@ def encode(self) -> dict[int, str]:
                         labels[pop] = pop_count
                         pop_count += 1
                     ints[i] = labels[pop]
+                    seen.add(pop)
                 # replace the "pop" labels
                 arr = rcf.drop_fields(blocks[strand_num], ["pop"])
                 blocks[strand_num] = rcf.merge_arrays((arr, ints), flatten=True)[names]
-        self.labels = labels
+        self.labels = {k: v for k, v in labels.items() if k in seen}
 
     def recode(self):
         """
@@ -332,6 +338,7 @@ def write(self):
         --------
         To write to a file, you must first initialize a Breakpoints object and then
         fill out the names, data, and samples properties:
+
         >>> from haptools.data import Breakpoints, HapBlock
         >>> breakpoints = Breakpoints('simple.bp')
         >>> breakpoints.data = {
diff --git a/tests/test_data.py b/tests/test_data.py
index d5d23fd8..2e0057c2 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -2075,6 +2075,40 @@ def test_encode(self):
                 for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
                     assert expected.labels[exp] == obs
 
+    def test_encode_reorder(self):
+        expected = self._get_expected_breakpoints()
+        expected.labels = {"CEU": 0, "YRI": 1}
+
+        observed = self._get_expected_breakpoints()
+        observed.encode(labels=("CEU", "YRI", "AMR"))
+
+        assert observed.labels == expected.labels
+        assert len(expected.data) == len(observed.data)
+        for sample in expected.data:
+            for strand in range(len(expected.data[sample])):
+                exp_strand = expected.data[sample][strand]
+                obs_strand = observed.data[sample][strand]
+                assert len(exp_strand) == len(observed.data[sample][strand])
+                for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
+                    assert expected.labels[exp] == obs
+
+        # now try again with AMR in the middle
+        # In that case, it should keep the ordering when deciding the integers
+        # but the final labels should not include the unused AMR key
+        expected.labels = {"CEU": 0, "YRI": 2}
+        observed = self._get_expected_breakpoints()
+        observed.encode(labels=("CEU", "AMR", "YRI"))
+
+        assert observed.labels == expected.labels
+        assert len(expected.data) == len(observed.data)
+        for sample in expected.data:
+            for strand in range(len(expected.data[sample])):
+                exp_strand = expected.data[sample][strand]
+                obs_strand = observed.data[sample][strand]
+                assert len(exp_strand) == len(observed.data[sample][strand])
+                for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
+                    assert expected.labels[exp] == obs
+
     def test_recode(self):
         expected = self._get_expected_breakpoints()