From 09a916a3bf4c7e3494507af1d13673e2b78c8ff3 Mon Sep 17 00:00:00 2001
From: Arya Massarat <>
Date: Thu, 5 Dec 2024 10:19:56 -0800
Subject: [PATCH] feat: allow for specifying the order of population labels in
 `Breakpoints.encode()` (#262)

 haptools/data/ | 25 ++++++++++++++++---------
 tests/           | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/haptools/data/ b/haptools/data/
index 08b118d9..6610ca46 100644
--- a/haptools/data/
+++ b/haptools/data/
@@ -160,26 +160,31 @@ def __iter__(self, samples: set[str] = None) -> Iterable[str, SampleBlocks]:
             yield samp, [np.array(b, dtype=HapBlock) for b in blocks]
-    def encode(self) -> dict[int, str]:
+    def encode(self, labels: tuple[str] = None):
         Replace each ancestral label in :py:attr:`~.Breakpoints.data` with an
         equivalent integer. Store a dictionary mapping these integers back to their
-        respective labels.
+        respective labels in :py:attr:`~.Breakpoints.labels`.
         This method modifies :py:attr:`~.Breakpoints.data` in place.
-        Returns
-        -------
-        dict[int, str]
-            A dictionary mapping each integer back to its ancestral label
+        Parameters
+        ----------
+        labels: tuple[str], optional
+            A tuple of population labels. Each label is encoded as its index in this
+            tuple. Labels that never appear in the data are omitted from the mapping.
         if not (self.labels is None):
             raise ValueError("The data has already been encoded.")
         # save the order of the fields for later reordering
         names = [f[0] for f in HapBlock]
         # initialize labels dict and label counter
-        labels = {}
-        pop_count = 0
+        if labels is None:
+            labels = {}
+        else:
+            labels = {pop: i for i, pop in enumerate(labels)}
+        pop_count = len(labels)
+        seen = set()
         for sample, blocks in
             for strand_num in range(len(blocks)):
                 # initialize and fill the array of integers
@@ -189,10 +194,11 @@ def encode(self) -> dict[int, str]:
                         labels[pop] = pop_count
                         pop_count += 1
                     ints[i] = labels[pop]
+                    seen.add(pop)
                 # replace the "pop" labels
                 arr = rcf.drop_fields(blocks[strand_num], ["pop"])
                 blocks[strand_num] = rcf.merge_arrays((arr, ints), flatten=True)[names]
-        self.labels = labels
+        self.labels = {k: v for k, v in labels.items() if k in seen}
     def recode(self):
@@ -332,6 +338,7 @@ def write(self):
         To write to a file, you must first initialize a Breakpoints object and then
         fill out the names, data, and samples properties:
         >>> from import Breakpoints, HapBlock
         >>> breakpoints = Breakpoints('simple.bp')
         >>> = {
diff --git a/tests/ b/tests/
index d5d23fd8..2e0057c2 100644
--- a/tests/
+++ b/tests/
@@ -2075,6 +2075,40 @@ def test_encode(self):
                 for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
                     assert expected.labels[exp] == obs
+    def test_encode_reorder(self):
+        expected = self._get_expected_breakpoints()
+        expected.labels = {"CEU": 0, "YRI": 1}
+        observed = self._get_expected_breakpoints()
+        observed.encode(labels=("CEU", "YRI", "AMR"))
+        assert observed.labels == expected.labels
+        assert len( == len(
+        for sample in
+            for strand in range(len([sample])):
+                exp_strand =[sample][strand]
+                obs_strand =[sample][strand]
+                assert len(exp_strand) == len([sample][strand])
+                for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
+                    assert expected.labels[exp] == obs
+        # now try again with AMR in the middle
+        # In that case, it should keep the ordering when deciding the integers
+        # but the final labels should not include the unused AMR key
+        expected.labels = {"CEU": 0, "YRI": 2}
+        observed = self._get_expected_breakpoints()
+        observed.encode(labels=("CEU", "AMR", "YRI"))
+        assert observed.labels == expected.labels
+        assert len( == len(
+        for sample in
+            for strand in range(len([sample])):
+                exp_strand =[sample][strand]
+                obs_strand =[sample][strand]
+                assert len(exp_strand) == len([sample][strand])
+                for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
+                    assert expected.labels[exp] == obs
     def test_recode(self):
         expected = self._get_expected_breakpoints()