Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanSteinberg committed Aug 16, 2024
1 parent 6ab215d commit 423f096
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ def subject_database(tmpdir: str, meds_dataset: str):
return meds_reader.SubjectDatabase(str(meds_reader_dir))


@pytest.fixture
def threaded_subject_database(tmpdir: str, meds_dataset: str):

meds_reader_dir = os.path.join(tmpdir, "meds_reader")

subprocess.run(
["meds_reader_convert", meds_dataset, meds_reader_dir, "--num_threads", "4"],
check=True,
)

return meds_reader.SubjectDatabase(str(meds_reader_dir), num_threads=4)


def test_metadata(subject_database):
with open(os.path.join(subject_database.path_to_database, "metadata", "dataset.json")) as f:
loaded_metadata = json.load(f)
Expand All @@ -114,12 +127,24 @@ def test_iter(subject_database):
assert list(subject_database) == [32, 64]


def test_map(subject_database):
def h(subjects):
result = []
for p in subjects:
result.append(p.subject_id)
return result
def h(subjects):
result = []
for p in subjects:
result.append(p.subject_id)
return result


def h2(subjects_and_data):
result = []
for subject, rows in subjects_and_data:
assert len(rows) == 1
row = rows[0]
print(subject, row)
result.append((subject.subject_id, row.other))
return result


def map_helper(subject_database):

results = list(subject_database.map(h))

Expand All @@ -131,15 +156,6 @@ def h(subjects):

table = pd.DataFrame({"subject_id": [64, 32], "other": [1, 1000]})

def h2(subjects_and_data):
result = []
for subject, rows in subjects_and_data:
assert len(rows) == 1
row = rows[0]
print(subject, row)
result.append((subject.subject_id, row.other))
return result

results = list(subject_database.map_with_data(h2, table))

final_result = {a for b in results for a in b}
Expand All @@ -149,6 +165,15 @@ def h2(subjects_and_data):
assert final_result == {(32, 1000), (64, 1)}


def test_map(subject_database):
map_helper(subject_database)


def test_map_threaded(threaded_subject_database):
map_helper(threaded_subject_database)
threaded_subject_database.terminate()


def test_properties(subject_database):
print(subject_database.properties)
assert subject_database.properties == {
Expand Down

0 comments on commit 423f096

Please sign in to comment.