From 05d8a867415cc10209376c9e4ef730fe8ab2e7a5 Mon Sep 17 00:00:00 2001
From: danlu1 <dan.lu@sagebase.org>
Date: Sat, 16 Mar 2024 00:45:56 +0000
Subject: [PATCH] add test function for Clinical.preprocess

---
 tests/test_clinical.py | 44 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/test_clinical.py b/tests/test_clinical.py
index 91815c97..7b8789ec 100644
--- a/tests/test_clinical.py
+++ b/tests/test_clinical.py
@@ -38,11 +38,33 @@ def table_query_results(*args):
     )
 )
 
+patientdf = pd.DataFrame(
+    dict(
+        fieldName=["PATIENT_ID", "SEX", "PRIMARY_RACE"],
+        patient=[True, True, True],
+        sample=[True, False, False],
+    )
+)
+sampledf = pd.DataFrame(
+    dict(
+        fieldName=["PATIENT_ID", "SAMPLE_ID"],
+        patient=[True, False],
+        sample=[True, True],
+    )
+)
+
+
 table_query_results_map = {
     ("select * from syn7434222",): createMockTable(sexdf),
     ("select * from syn7434236",): createMockTable(no_nan),
     ("select * from syn7434242",): createMockTable(no_nan),
     ("select * from syn7434273",): createMockTable(no_nan),
+    (
+        "select fieldName from syn8545211 where patient is True and inClinicalDb is True",
+    ): createMockTable(patientdf),
+    (
+        "select fieldName from syn8545211 where sample is True and inClinicalDb is True",
+    ): createMockTable(sampledf),
 }
 
 json_oncotreeurl = (
@@ -1382,3 +1404,25 @@ def test_that__cross_validate_assay_info_has_seq_returns_expected_msg_if_valid(
         )
         assert warnings == expected_warning
         assert errors == expected_error
+
+
+def test_preprocess(clin_class, newpath=None):
+    expected = {
+        "clinicalTemplate": pd.DataFrame(
+            columns=["PATIENT_ID", "SEX", "PRIMARY_RACE", "SAMPLE_ID"]
+        ),
+        "sample": True,
+        "patient": True,
+        "patientCols": ["PATIENT_ID", "SEX", "PRIMARY_RACE"],
+        "sampleCols": ["PATIENT_ID", "SAMPLE_ID"],
+    }
+    results = clin_class.preprocess(newpath)
+    assert (
+        results["clinicalTemplate"]
+        .sort_index(axis=1)
+        .equals(expected["clinicalTemplate"].sort_index(axis=1))
+    )
+    assert results["sample"] == expected["sample"]
+    assert results["patient"] == expected["patient"]
+    assert results["patientCols"] == expected["patientCols"]
+    assert results["sampleCols"] == expected["sampleCols"]