diff --git a/genie_registry/clinical.py b/genie_registry/clinical.py index e5a4244c..3c497edb 100644 --- a/genie_registry/clinical.py +++ b/genie_registry/clinical.py @@ -392,7 +392,7 @@ def preprocess(self, newpath): "sample is True and inClinicalDb is True" ) sample_cols = sample_cols_table.asDataFrame()["fieldName"].tolist() - clinicalTemplate = pd.DataFrame(columns=set(patient_cols + sample_cols)) + clinicalTemplate = pd.DataFrame(columns=list(set(patient_cols + sample_cols))) sample = True patient = True diff --git a/tests/test_clinical.py b/tests/test_clinical.py index 91815c97..c2ad2c92 100644 --- a/tests/test_clinical.py +++ b/tests/test_clinical.py @@ -38,11 +38,33 @@ def table_query_results(*args): ) ) +patientdf = pd.DataFrame( + dict( + fieldName=["PATIENT_ID", "SEX", "PRIMARY_RACE"], + patient=[True, True, True], + sample=[True, False, False], + ) +) +sampledf = pd.DataFrame( + dict( + fieldName=["PATIENT_ID", "SAMPLE_ID"], + patient=[True, False], + sample=[True, True], + ) +) + + table_query_results_map = { ("select * from syn7434222",): createMockTable(sexdf), ("select * from syn7434236",): createMockTable(no_nan), ("select * from syn7434242",): createMockTable(no_nan), ("select * from syn7434273",): createMockTable(no_nan), + ( + "select fieldName from syn8545211 where patient is True and inClinicalDb is True", + ): createMockTable(patientdf), + ( + "select fieldName from syn8545211 where sample is True and inClinicalDb is True", + ): createMockTable(sampledf), } json_oncotreeurl = ( @@ -1382,3 +1404,26 @@ def test_that__cross_validate_assay_info_has_seq_returns_expected_msg_if_valid( ) assert warnings == expected_warning assert errors == expected_error + + +def test_preprocess(clin_class, newpath=None): + """Test preprocess function""" + expected = { + "clinicalTemplate": pd.DataFrame( + columns=["PATIENT_ID", "SEX", "PRIMARY_RACE", "SAMPLE_ID"] + ), + "sample": True, + "patient": True, + "patientCols": ["PATIENT_ID", "SEX", "PRIMARY_RACE"], + "sampleCols": ["PATIENT_ID", "SAMPLE_ID"], + } + results = clin_class.preprocess(newpath) + assert ( + results["clinicalTemplate"] + .sort_index(axis=1) + .equals(expected["clinicalTemplate"].sort_index(axis=1)) + ) + assert results["sample"] == expected["sample"] + assert results["patient"] == expected["patient"] + assert results["patientCols"] == expected["patientCols"] + assert results["sampleCols"] == expected["sampleCols"]