Skip to content

Commit

Permalink
Update test harness to be compatible with latest Treelite
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Oct 3, 2024
1 parent d4e7fd7 commit eb9ba75
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions qa/L0_e2e/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,27 @@ def __init__(self, model_path, model_format, output_class):
self.output_class = output_class

def _predict(self, arr):
return treelite.gtil.predict(self.tl_model, arr)
result = treelite.gtil.predict(self.tl_model, arr)
# GTIL always returns prediction result with dimensions
# (num_row, num_target, num_class)
assert len(result.shape) == 3
# We don't test multi-target models
# TODO(hcho3): Add coverage for multi-target models
assert result.shape[1] == 1
return result[:, 0, :]

def predict_proba(self, arr):
result = self._predict(arr)
if len(result.shape) > 1:
if result.shape[1] > 1:
return result
else:
return np.transpose(np.vstack((1 - result, result)))
return np.hstack((1 - result, result))

def predict(self, arr):
if self.output_class:
return np.argmax(self.predict_proba(arr), axis=1)
else:
return self._predict(arr)
return self._predict(arr).squeeze()


class GroundTruthModel:
Expand Down

0 comments on commit eb9ba75

Please sign in to comment.