Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed May 17, 2024
1 parent c6c8100 commit 4237d88
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions training/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def model(request):
@pytest.mark.parametrize(
"sample",
[
("select * from users where id=1 or 1=1;", [99.99, 99.83]),
("select * from users where id='1' or 1=1--", [92.02, 99.83]),
("select * from users where id=1 or 1=1;", [99.99, 28.87]),
("select * from users where id='1' or 1=1--", [92.02, 28.87]),
("select * from users", [0.077, 0.08]),
("select * from users where id=10000", [14.83, 97.32]),
("select * from users where id=10000", [14.83, 4.137]),
("select '1' union select 'a'; -- -'", [99.99, 97.32]),
(
"select '' union select 'malicious php code' \g /var/www/test.php; -- -';",
Expand All @@ -71,7 +71,11 @@ def test_sqli_model(model, sample):
predictions = model["sqli_model"](sample_vec)

# Scale up to 100
print(predictions["dense"].numpy() * 100) # Debugging purposes (prints on error)
assert predictions["dense"].numpy() * 100 == pytest.approx(
output = "dense"
if "output_0" in predictions:
output = "output_0" # Model v2 uses output_0 instead of dense

print(predictions[output].numpy() * 100) # Debugging purposes (prints on error)
assert predictions[output].numpy() * 100 == pytest.approx(
np.array([[sample[1][model["index"]]]]), 0.1
)

0 comments on commit 4237d88

Please sign in to comment.