Skip to content

Commit

Permalink
feat: allow for specifying the order of population labels in `Breakpo…
Browse files Browse the repository at this point in the history
…ints.encode()` (#262)
  • Loading branch information
aryarm authored Dec 5, 2024
1 parent 16a84d1 commit 09a916a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
25 changes: 16 additions & 9 deletions haptools/data/breakpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,31 @@ def __iter__(self, samples: set[str] = None) -> Iterable[str, SampleBlocks]:
yield samp, [np.array(b, dtype=HapBlock) for b in blocks]
bps.close()

def encode(self) -> dict[int, str]:
def encode(self, labels: tuple[str] = None):
"""
Replace each ancestral label in :py:attr:`~.Breakpoints.data` with an
equivalent integer. Store a dictionary mapping these integers back to their
respective labels.
respective labels in :py:attr:`~.Breakpoints.labels`.
This method modifies :py:attr:`~.Breakpoints.data` in place.
Returns
-------
dict[int, str]
A dictionary mapping each integer back to its ancestral label
Parameters
----------
labels: tuple[str], optional
A list of population labels. The order of the labels in this list will be
kept in the respective labels.
"""
if not (self.labels is None):
raise ValueError("The data has already been encoded.")
# save the order of the fields for later reordering
names = [f[0] for f in HapBlock]
# initialize labels dict and label counter
labels = {}
pop_count = 0
if labels is None:
labels = {}
else:
labels = {pop: i for i, pop in enumerate(labels)}
pop_count = len(labels)
seen = set()
for sample, blocks in self.data.items():
for strand_num in range(len(blocks)):
# initialize and fill the array of integers
Expand All @@ -189,10 +194,11 @@ def encode(self) -> dict[int, str]:
labels[pop] = pop_count
pop_count += 1
ints[i] = labels[pop]
seen.add(pop)
# replace the "pop" labels
arr = rcf.drop_fields(blocks[strand_num], ["pop"])
blocks[strand_num] = rcf.merge_arrays((arr, ints), flatten=True)[names]
self.labels = labels
self.labels = {k: v for k, v in labels.items() if k in seen}

def recode(self):
"""
Expand Down Expand Up @@ -332,6 +338,7 @@ def write(self):
--------
To write to a file, you must first initialize a Breakpoints object and then
fill out the names, data, and samples properties:
>>> from haptools.data import Breakpoints, HapBlock
>>> breakpoints = Breakpoints('simple.bp')
>>> breakpoints.data = {
Expand Down
34 changes: 34 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2075,6 +2075,40 @@ def test_encode(self):
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

def test_encode_reorder(self):
expected = self._get_expected_breakpoints()
expected.labels = {"CEU": 0, "YRI": 1}

observed = self._get_expected_breakpoints()
observed.encode(labels=("CEU", "YRI", "AMR"))

assert observed.labels == expected.labels
assert len(expected.data) == len(observed.data)
for sample in expected.data:
for strand in range(len(expected.data[sample])):
exp_strand = expected.data[sample][strand]
obs_strand = observed.data[sample][strand]
assert len(exp_strand) == len(observed.data[sample][strand])
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

# now try again with AMR in the middle
# In that case, it should keep the ordering when deciding the integers
# but the final labels should include the AMR key
expected.labels = {"CEU": 0, "YRI": 2}
observed = self._get_expected_breakpoints()
observed.encode(labels=("CEU", "AMR", "YRI"))

assert observed.labels == expected.labels
assert len(expected.data) == len(observed.data)
for sample in expected.data:
for strand in range(len(expected.data[sample])):
exp_strand = expected.data[sample][strand]
obs_strand = observed.data[sample][strand]
assert len(exp_strand) == len(observed.data[sample][strand])
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

def test_recode(self):
expected = self._get_expected_breakpoints()

Expand Down

0 comments on commit 09a916a

Please sign in to comment.