Skip to content

Commit

Permalink
refactoring tests and only searching even batches
Browse files Browse the repository at this point in the history
  • Loading branch information
jbedichekTT committed Feb 25, 2025
1 parent d92db1f commit 68fbb73
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 47 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/batch-experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ jobs:
echo "Expanding upper bound to: $batch_range_upper" # Optional logging
fi
local batch_size_to_test=$(( (batch_range[0] + batch_range[1]) / 2 ))
if (( batch_size_to_test % 2 != 0)); then
batch_size_to_test=$batch_size_to_test-1
fi
echo "Testing with batch size $batch_size_to_test"
python3 -m pytest "$test_path" -s --batch_size $batch_size_to_test --report_nth_iteration $num_iterations
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def skip_by_platform(request, device):


@pytest.fixture(autouse=True)
def compile_and_run(device, reset_torch_dynamo, request, batch_size):
def compile_and_run(device, reset_torch_dynamo, request):
logging.info("Starting the compile_and_run fixture.")

runtime_metrics = {"success": False} # Initialize early to ensure it's defined
Expand Down Expand Up @@ -182,8 +182,8 @@ def compile_and_run(device, reset_torch_dynamo, request, batch_size):
for idx in range(int(request.config.getoption("--report_nth_iteration"))):
start = time.perf_counter() * 1000
# Don't need to reset options if inputs don't change because of cache

outputs_after = model_tester.test_model(as_ttnn=True, option=option)
# return
end = time.perf_counter() * 1000
run_time = end - start
if idx == 0:
Expand Down
51 changes: 6 additions & 45 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, model_name, mode, batch_size=None):
self.inputs = self._load_inputs()
self.batch_size = batch_size
self.validate_batch_size()
self.batch_inputs()

def _load_model(self):
raise NotImplementedError("This method should be implemented in the derived class")
Expand Down Expand Up @@ -71,23 +72,6 @@ def run_model(self, model, inputs):
else:
return model(inputs)

def run_model_batched(self, model, inputs):
# This creates a batch of duplicates (all items in the batch are the same, just repeated
# Naively to create a batch)
def repeat_tensor(x):
x = x.squeeze(0)
x = x.repeat(self.batch_size, *([1] * (x.dim()))) # Repeat along batch dim
return x

if isinstance(inputs, collections.abc.Mapping):
batched_inputs = {k: repeat_tensor(v) for k, v in inputs.items()}
return model(**batched_inputs)
elif isinstance(inputs, collections.abc.Sequence) and not isinstance(inputs, (str, bytes)):
batched_inputs = [repeat_tensor(x) for x in inputs]
return model(*batched_inputs)
else:
return model(repeat_tensor(inputs))

def append_fake_loss_function(self, outputs):
# Using `torch.mean` as the loss function for testing purposes.
#
Expand Down Expand Up @@ -138,15 +122,11 @@ def test_model_train(self, as_ttnn=False, option=None):
inputs = self.set_inputs_train(self.inputs)
if as_ttnn == True:
model = self.compile_model(model, option)
if self.batch_size is not None:
outputs = self.run_model_batched(model, inputs)
else:
outputs = self.run_model(model, inputs)
outputs = self.run_model(model, inputs)
loss = self.append_fake_loss_function(outputs)
loss.backward()
# Again, use the gradient of the input (`test_input.grad`) as the golden result for the training process.
results = self.get_results_train(model, inputs, outputs)
self.batch_inputs()
return results

@torch.no_grad()
Expand All @@ -156,12 +136,8 @@ def test_model_eval(self, as_ttnn=False, option=None):
inputs = self.set_inputs_eval(self.inputs)
if as_ttnn == True:
model = self.compile_model(model, option)
if self.batch_size is not None:
outputs = self.run_model_batched(model, inputs)
else:
outputs = self.run_model(model, inputs)
outputs = self.run_model(model, inputs)
results = self.get_results_eval(model, inputs, outputs)
self.batch_inputs()
return results

def test_model(self, as_ttnn=False, option=None):
Expand All @@ -182,10 +158,13 @@ def batch_inputs(self):
if isinstance(self.inputs[key], torch.Tensor):
self.inputs[key] = self.inputs[key].repeat(self.batch_size, 1)
elif isinstance(self.inputs, torch.Tensor):
# if self.inputs.ndim < 4:
print(self.inputs.shape)
if self.inputs.shape[0] == 0:
self.inputs = self.inputs.squeeze(0)
self.inputs = self.inputs.repeat(self.batch_size, *([1] * (self.inputs.dim())))
self.inputs = self.inputs.squeeze(1)
print(self.inputs.shape)
else:
raise TypeError(f"Unregonized inputs type: {type(self.inputs)}")

Expand Down Expand Up @@ -620,21 +599,3 @@ def process_batched_logits(logits, batch_size):
return logits[0, :].squeeze(0)
else:
raise ValueError(f"Unrecognized logit dimension: {logits.shape.numel()} (not 2D or 3D including batch)")


def batch_object_inputs(tester_obj, batch_size):
if batch_size is None:
return
inputs = tester_obj.inputs
if isinstance(inputs, dict) or isinstance(inputs, transformers.tokenization_utils_base.BatchEncoding):
keys = inputs.keys()
for key in keys:
if isinstance(inputs[key], torch.Tensor):
inputs[key] = inputs[key].repeat(batch_size, 1)
elif isinstance(inputs, torch.Tensor):
if inputs.shape[0] == 0:
inputs = inputs.squeeze(0)
tester_obj.inputs = inputs.repeat(batch_size, *([1] * (inputs.dim())))
tester_obj.inputs = tester_obj.inputs.squeeze(1)
else:
raise TypeError(f"Unregonized inputs type: {type(inputs)}")

0 comments on commit 68fbb73

Please sign in to comment.