diff --git a/eir/predict_modules/predict_output_modules/predict_survival_output.py b/eir/predict_modules/predict_output_modules/predict_survival_output.py index f9f7839b..6e22cfb4 100644 --- a/eir/predict_modules/predict_output_modules/predict_survival_output.py +++ b/eir/predict_modules/predict_output_modules/predict_survival_output.py @@ -122,7 +122,7 @@ def predict_survival_wrapper_with_labels( ) else: - risk_scores = model_outputs.numpy() + risk_scores = model_outputs.cpu().numpy() baseline_hazard = output_object.baseline_hazard unique_times = output_object.baseline_unique_times diff --git a/eir/train_utils/evaluation_modules/evaluation_output_survival.py b/eir/train_utils/evaluation_modules/evaluation_output_survival.py index 1de669a0..2e5e7734 100644 --- a/eir/train_utils/evaluation_modules/evaluation_output_survival.py +++ b/eir/train_utils/evaluation_modules/evaluation_output_survival.py @@ -109,7 +109,7 @@ def save_survival_evaluation_results_wrapper( ) else: - risk_scores = model_outputs.numpy() + risk_scores = model_outputs.cpu().numpy() unique_times, baseline_hazard = estimate_baseline_hazard( times=times,