Skip to content

Commit

Permalink
Changed device of torch dataset in predict method to avoid potential …
Browse files Browse the repository at this point in the history
…error
  • Loading branch information
danyoungday committed Jul 31, 2024
1 parent bef6351 commit c0e5300
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion prsdk/predictors/neural_network/neural_net_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def predict(self, context_actions_df: pd.DataFrame) -> pd.DataFrame:
:return: DataFrame of predictions properly labeled and indexed.
"""
X_test_scaled = self.scaler.transform(context_actions_df[self.features])
test_ds = TorchDataset(X_test_scaled, np.zeros(len(X_test_scaled)))
test_ds = TorchDataset(X_test_scaled, np.zeros(len(X_test_scaled)), device=self.device)
test_dl = DataLoader(test_ds, self.batch_size, shuffle=False)
pred_list = []
with torch.no_grad():
Expand Down

0 comments on commit c0e5300

Please sign in to comment.