From feabb65220199aeabbf7483b9c31fcc8718a6127 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 5 Nov 2024 15:28:27 +0000 Subject: [PATCH] Add low-level Python-C support for arbitrary derived state --- _tsinfermodule.c | 35 ++++++++++++++++++++++++++++++----- tests/test_low_level.py | 21 ++++++++++++++++++--- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index a887f962..7645d9a9 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -60,6 +60,17 @@ uint64_PyArray_converter(PyObject *in, PyObject **out) return NPY_SUCCEED; } +static int +int8_PyArray_converter(PyObject *in, PyObject **out) +{ + PyObject *ret = PyArray_FROMANY(in, NPY_INT8, 1, 1, NPY_ARRAY_IN_ARRAY); + if (ret == NULL) { + return NPY_FAIL; + } + *out = ret; + return NPY_SUCCEED; +} + /*=================================================================== * AncestorBuilder *=================================================================== @@ -429,8 +440,11 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw { int ret = -1; int err; - static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", NULL}; + static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", "derived_state", + NULL}; PyArrayObject *num_alleles = NULL; + PyArrayObject *derived_state = NULL; + int8_t *derived_state_data = NULL; unsigned long max_nodes = 1024; unsigned long max_edges = 1024; unsigned long num_sites; @@ -438,21 +452,31 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw int flags = 0; self->tree_sequence_builder = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kk", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kkO&", kwlist, uint64_PyArray_converter, &num_alleles, - &max_nodes, &max_edges)) { + &max_nodes, &max_edges, + int8_PyArray_converter, &derived_state)) { goto out; } shape = PyArray_DIMS(num_alleles); num_sites = shape[0]; - + if (derived_state != NULL) { 
+ shape = PyArray_DIMS(derived_state); + if (shape[0] != (npy_intp) num_sites) { + PyErr_SetString(PyExc_ValueError, "derived state array wrong size"); + goto out; + } + derived_state_data = PyArray_DATA(derived_state); + } self->tree_sequence_builder = PyMem_Malloc(sizeof(tree_sequence_builder_t)); if (self->tree_sequence_builder == NULL) { PyErr_NoMemory(); goto out; } err = tree_sequence_builder_alloc(self->tree_sequence_builder, - num_sites, PyArray_DATA(num_alleles), + num_sites, + PyArray_DATA(num_alleles), + derived_state_data, max_nodes, max_edges, flags); if (err != 0) { handle_library_error(err); @@ -461,6 +485,7 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw ret = 0; out: Py_XDECREF(num_alleles); + Py_XDECREF(derived_state); return ret; } diff --git a/tests/test_low_level.py b/tests/test_low_level.py index 5679499a..42932b85 100644 --- a/tests/test_low_level.py +++ b/tests/test_low_level.py @@ -88,9 +88,6 @@ class TestTreeSequenceBuilder: def test_init(self): with pytest.raises(TypeError): _tsinfer.TreeSequenceBuilder() - for bad_array in [None, "serf", [[], []], ["asdf"], {}]: - with pytest.raises(ValueError): - _tsinfer.TreeSequenceBuilder(bad_array) for bad_type in [None, "sdf", {}]: with pytest.raises(TypeError): @@ -98,6 +95,24 @@ def test_init(self): with pytest.raises(TypeError): _tsinfer.TreeSequenceBuilder([2], max_edges=bad_type) + def test_bad_num_alleles(self): + for bad_array in [None, "serf", [[], []], ["asdf"], {}]: + with pytest.raises(ValueError): + _tsinfer.TreeSequenceBuilder(bad_array) + with pytest.raises(_tsinfer.LibraryError, match="number of alleles"): + _tsinfer.TreeSequenceBuilder([1000]) + + def test_bad_derived_state(self): + for bad_array in [None, "serf", [[], []], ["asdf"], {}]: + with pytest.raises(ValueError): + _tsinfer.TreeSequenceBuilder([2], derived_state=bad_array) + with pytest.raises(_tsinfer.LibraryError, match="Bad derived state"): + for bad_derived_state in [-1, 2, 100]: + 
_tsinfer.TreeSequenceBuilder([2], derived_state=[bad_derived_state]) + + with pytest.raises(ValueError, match="derived state array wrong size"): + _tsinfer.TreeSequenceBuilder([2, 2, 2], derived_state=[]) + class TestAncestorBuilder: """