#0: Added a naive batching mechanism for some models in the tests folder #740
First changed file (Albert masked-LM test):
```diff
@@ -3,7 +3,7 @@
 from transformers import AutoTokenizer, AlbertForMaskedLM
 import torch
 import pytest
-from tests.utils import ModelTester
+from tests.utils import ModelTester, validate_batch_size, process_batched_logits
 
 
 class ThisTester(ModelTester):
@@ -40,15 +40,25 @@ def append_fake_loss_function(self, outputs):
         "albert/albert-xxlarge-v2",
     ],
 )
-def test_albert_masked_lm(record_property, model_name, mode):
+def test_albert_masked_lm(record_property, model_name, mode, get_batch_size):
     record_property("model_name", model_name)
     record_property("mode", mode)
 
+    batch_size = get_batch_size
+    if batch_size is not None:
+        batch_size = int(batch_size)
+        validate_batch_size(batch_size)
+
     tester = ThisTester(model_name, mode)
-    results = tester.test_model()
+    results = tester.test_model(batch_size=batch_size)
     if mode == "eval":
         # retrieve index of [MASK]
+        results.logits = process_batched_logits(results.logits, batch_size)
+        #print(results.logits.shape)
         logits = results.logits
         mask_token_index = (tester.inputs.input_ids == tester.tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
         predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
```

Reviewer comment on the `#print(results.logits.shape)` line: remove
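The `validate_batch_size` and `process_batched_logits` helpers are imported from `tests.utils`, but their bodies are not part of this diff. A minimal sketch of what they might look like, assuming validation only rejects non-positive sizes and logit processing collapses the faked batch back to a single sample so the mask-token assertions below keep working:

```python
# Hypothetical sketch of the tests.utils helpers used above; the PR does not
# show their implementations, so both the check and the reduction are assumptions.

def validate_batch_size(batch_size):
    # Assumed check: the naive batching mechanism only supports positive sizes.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

def process_batched_logits(logits, batch_size):
    # Assumed behavior: inputs were duplicated along dim 0 to fake a batch,
    # so every row holds the same sample; keep only the first row so the
    # single-sample post-processing below still sees shape [1, seq, vocab].
    if batch_size is None or batch_size == 1:
        return logits
    return logits[:1]
```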
Second changed file (BERT question-answering test):
```diff
@@ -3,7 +3,7 @@
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
-from tests.utils import ModelTester
+from tests.utils import ModelTester, validate_batch_size, process_batched_logits, batch_object_inputs
 
 
 class ThisTester(ModelTester):
@@ -35,13 +35,19 @@ def _load_inputs(self):
     ["eval"],
 )
 @pytest.mark.converted_end_to_end
-def test_bert(record_property, mode):
+def test_bert(record_property, mode, get_batch_size):
```

Reviewer comment on `def test_bert(record_property, mode, get_batch_size):`: Instead of this being an option of the run here, I think it might be better to simply pass it to the ModelTester in conftest, like this: `outputs_after = model_tester.test_model(as_ttnn=True, option=option, batch_size=batch_size)`

Reply: Will do
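A minimal sketch of that suggestion, assuming conftest.py has a wrapper that drives `ModelTester` (only the `test_model(...)` call is quoted from the comment; the wrapper and its name are illustrative):

```python
# Hypothetical conftest.py helper; only the test_model(...) call comes from
# the review comment, the surrounding function and its name are assumed.
def run_model_test(model_tester, option, batch_size):
    # conftest forwards batch_size itself, so individual tests no longer
    # validate or pass it manually.
    outputs_after = model_tester.test_model(as_ttnn=True, option=option, batch_size=batch_size)
    return outputs_after
```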
The same hunk continues below the comment thread:

```diff
     model_name = "BERT"
     record_property("model_name", model_name)
     record_property("mode", mode)
 
+    batch_size = get_batch_size
+    if batch_size is not None:
+        batch_size = int(batch_size)
+        validate_batch_size(batch_size)
+
     tester = ThisTester(model_name, mode)
-    results = tester.test_model()
+    results = tester.test_model(batch_size=batch_size)
+    batch_object_inputs(tester, batch_size)  # This is necessary to avoid shape mismatch errors in tester processing
 
     if mode == "eval":
         # Helper function to decode output to human-readable text
```
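As with the other helpers, `batch_object_inputs` is not shown in this diff. A plausible sketch, assuming it repeats each cached input tensor along dim 0 so the tester's inputs match the batched outputs (the attribute and key handling are guesses):

```python
# Hypothetical sketch of tests.utils.batch_object_inputs; the real body is
# not in the diff. Assumes tester.inputs is a BatchEncoding-like mapping of
# tensors and that "batching" means repeating the single sample along dim 0.
import torch

def batch_object_inputs(tester, batch_size):
    if batch_size is None or batch_size == 1:
        return
    for key, value in tester.inputs.items():
        if torch.is_tensor(value):
            # Repeat the sample batch_size times along dim 0 so input shapes
            # line up with the batched model outputs during post-processing.
            tester.inputs[key] = value.repeat(batch_size, *([1] * (value.dim() - 1)))
```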
Reviewer comment on the `get_batch_size` fixture: The returned value is a value, not a function, so a name starting with "get" is a bit weird here. Wdyt?

Reply: Maybe just "batch_size"?

Reply: I agree, looks cleaner.
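A sketch of the agreed rename, assuming the fixture lives in conftest.py and is backed by a `--batch_size` command-line option (the option name, default, and int conversion are assumptions, not shown in the PR):

```python
# Hypothetical conftest.py after the rename; the --batch_size option name
# and the conversion here are assumptions.
import pytest

def pytest_addoption(parser):
    parser.addoption("--batch_size", action="store", default=None,
                     help="naive batch size to apply to model inputs")

@pytest.fixture
def batch_size(request):
    # Returns the value directly, so the "get_" prefix is no longer needed;
    # converting to int here would also remove the int(...) calls from each test.
    value = request.config.getoption("--batch_size")
    return int(value) if value is not None else None
```

Tests would then take the value directly, e.g. `def test_bert(record_property, mode, batch_size):`.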