diff --git a/tests/api_test.py b/tests/api_test.py
index 99d143b..ba8d736 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -95,6 +95,19 @@ def subject_database(tmpdir: str, meds_dataset: str):
     return meds_reader.SubjectDatabase(str(meds_reader_dir))
 
 
+@pytest.fixture
+def threaded_subject_database(tmpdir: str, meds_dataset: str):
+
+    meds_reader_dir = os.path.join(tmpdir, "meds_reader")
+
+    subprocess.run(
+        ["meds_reader_convert", meds_dataset, meds_reader_dir, "--num_threads", "4"],
+        check=True,
+    )
+
+    return meds_reader.SubjectDatabase(str(meds_reader_dir), num_threads=4)
+
+
 def test_metadata(subject_database):
     with open(os.path.join(subject_database.path_to_database, "metadata", "dataset.json")) as f:
         loaded_metadata = json.load(f)
@@ -114,12 +127,24 @@ def test_iter(subject_database):
     assert list(subject_database) == [32, 64]
 
 
-def test_map(subject_database):
-    def h(subjects):
-        result = []
-        for p in subjects:
-            result.append(p.subject_id)
-        return result
+def h(subjects):
+    result = []
+    for p in subjects:
+        result.append(p.subject_id)
+    return result
+
+
+def h2(subjects_and_data):
+    result = []
+    for subject, rows in subjects_and_data:
+        assert len(rows) == 1
+        row = rows[0]
+        print(subject, row)
+        result.append((subject.subject_id, row.other))
+    return result
+
+
+def map_helper(subject_database):
 
     results = list(subject_database.map(h))
 
@@ -131,15 +156,6 @@ def h(subjects):
 
     table = pd.DataFrame({"subject_id": [64, 32], "other": [1, 1000]})
 
-    def h2(subjects_and_data):
-        result = []
-        for subject, rows in subjects_and_data:
-            assert len(rows) == 1
-            row = rows[0]
-            print(subject, row)
-            result.append((subject.subject_id, row.other))
-        return result
-
     results = list(subject_database.map_with_data(h2, table))
 
     final_result = {a for b in results for a in b}
@@ -149,6 +165,15 @@ def h2(subjects_and_data):
     assert final_result == {(32, 1000), (64, 1)}
 
 
+def test_map(subject_database):
+    map_helper(subject_database)
+
+
+def test_map_threaded(threaded_subject_database):
+    map_helper(threaded_subject_database)
+    threaded_subject_database.terminate()
+
+
 def test_properties(subject_database):
     print(subject_database.properties)
     assert subject_database.properties == {