From 41b0356dc5dd3b7ec6a412bab4f637e317ccd4b7 Mon Sep 17 00:00:00 2001
From: danlu1 <dan.lu@sagebase.org>
Date: Tue, 23 Apr 2024 02:12:35 +0000
Subject: [PATCH] add unit tests

---
 tests/test_clinical.py | 59 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 55 insertions(+), 4 deletions(-)

diff --git a/tests/test_clinical.py b/tests/test_clinical.py
index 91815c97..fead33dd 100644
--- a/tests/test_clinical.py
+++ b/tests/test_clinical.py
@@ -1,14 +1,13 @@
-from collections import Counter
 import datetime
 import json
+from collections import Counter
 from unittest import mock
 
+import genie_registry
 import pandas as pd
 import pytest
 import synapseclient
-
 from genie import process_functions, validate
-import genie_registry
 from genie_registry.clinical import Clinical
 
 
@@ -662,7 +661,6 @@ def test_errors__validate(clin_class):
     ) as mock_get_onco_map:
         error, warning = clin_class._validate(clinicalDf)
         mock_get_onco_map.called_once_with(json_oncotreeurl)
-
         expectedErrors = (
             "Sample Clinical File: SAMPLE_ID must start with GENIE-SAGE\n"
             "Patient Clinical File: PATIENT_ID must start with GENIE-SAGE\n"
@@ -700,6 +698,8 @@ def test_errors__validate(clin_class):
             "Patient Clinical File: Please double check your YEAR_CONTACT "
             "column, it must be an integer in YYYY format <= {year} or "
             "'Unknown', 'Not Collected', 'Not Released', '>89', '<18'.\n"
+            "Patient Clinical File: Please double check your YEAR_DEATH "
+            "and YEAR_CONTACT columns. YEAR_DEATH must be >= YEAR_CONTACT.\n"
             "Patient Clinical File: Please double check your INT_CONTACT "
             "column, it must be an integer, '>32485', '<6570', 'Unknown', "
             "'Not Released' or 'Not Collected'.\n"
@@ -1062,6 +1062,57 @@ def test__check_int_dead_consistency_inconsistent(inconsistent_df):
     )
 
 
+@pytest.mark.parametrize(
+    "df,expected_error",
+    [
+        (
+            pd.DataFrame({"YEAR_DEATH": [420, 555, 390], "YEAR_CONTACT": [50, 40, 22]}),
+            "",
+        ),
+        (
+            pd.DataFrame(
+                {
+                    "YEAR_DEATH": [420, float("nan"), 390],
+                    "YEAR_CONTACT": [50, 40, float("nan")],
+                }
+            ),
+            "",
+        ),
+        (
+            pd.DataFrame(
+                {
+                    "YEAR_DEATH": [float("nan"), float("nan"), 390],
+                    "YEAR_CONTACT": [50, 40, float("nan")],
+                }
+            ),
+            "",
+        ),
+        (
+            pd.DataFrame(
+                {"YEAR_DEATH": [420, 666, 390], "YEAR_CONTACT": [50, 40, 555]}
+            ),
+            "Patient Clinical File: Please double check your YEAR_DEATH and YEAR_CONTACT columns. YEAR_DEATH must be >= YEAR_CONTACT.\n",
+        ),
+        (
+            pd.DataFrame(
+                {"YEAR_DEATH": [420, float("nan"), 390], "YEAR_CONTACT": [50, 40, 555]}
+            ),
+            "Patient Clinical File: Please double check your YEAR_DEATH and YEAR_CONTACT columns. YEAR_DEATH must be >= YEAR_CONTACT.\n",
+        ),
+    ],
+    ids=[
+        "valid_dataframe_no_NAs",
+        "valid_dataframe_w_NAs",
+        "valid_dataframe_all_NAs",
+        "invalid_dataframe_no_NAs",
+        "invalid_dataframe_w_NAs",
+    ],
+)
+def test__check_year_death_validity(df, expected_error):
+    error = genie_registry.clinical._check_year_death_validity(clinicaldf=df)
+    assert error == expected_error
+
+
 def get_cross_validate_bed_files_test_cases():
     return [
         {