[test] add e2e on device training test
Adding an MNIST training test which runs everything on device, i.e. the
fwd/bwd passes of the model and the loss function, as well as the optimizer step.

Additionally, modified the `test_lora` test to use the optimizer on device.
pilkicTT committed Feb 25, 2025
1 parent 682b974 commit 1775e3c
Showing 1 changed file with 84 additions and 23 deletions.
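
For context, below is a minimal Python sketch of the end-to-end on-device training loop this commit adds, condensed from the new test_e2e_device test in the diff. It relies only on the Forge APIs that appear there (forge.compile with optimizer= and attach_to=, forge.optimizers.SGD, and the compiled loss module's backward()); the helpers MNISTLinear, CrossEntropyLoss, load_dataset and to_forge_tensors come from the MNIST test utilities, so their imports are omitted here.

import torch
from torch import nn

import forge

# Hyperparameters as used in test_e2e_device.
num_epochs, batch_size, learning_rate = 5, 1024, 0.1

# Data loaders and model from the MNIST test utilities (imports omitted).
test_loader, train_loader = load_dataset(batch_size)
framework_model = MNISTLinear(bias=False)

# On-device SGD optimizer, attached to the model at compile time.
tt_optimizer = forge.optimizers.SGD(learning_rate=learning_rate)
tt_model = forge.compile(
    framework_model, sample_inputs=[torch.rand(batch_size, 784)], optimizer=tt_optimizer, training=True
)

# Compile the loss for training as well, attached to the model so that its
# backward pass also runs on device.
loss_inputs = to_forge_tensors([torch.rand(batch_size, 10).requires_grad_(True), torch.rand(batch_size, 10)])
tt_loss = forge.compile(
    CrossEntropyLoss(name="cross_entropy_loss"), sample_inputs=loss_inputs, training=True, attach_to=tt_model
)

for epoch_idx in range(num_epochs):
    for data, target in train_loader:
        target = nn.functional.one_hot(target, num_classes=10).float()  # targets stay on CPU
        pred = tt_model(data)[0]         # forward pass on device
        loss = tt_loss(pred, target)[0]  # loss forward on device
        tt_loss.backward()               # backward pass on device
        tt_optimizer.step()              # weight update on device; also zeroes the gradients
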
107 changes: 84 additions & 23 deletions forge/test/mlir/mnist/training/test_training.py
@@ -19,8 +19,6 @@

@pytest.mark.push
def test_mnist_training():
torch.manual_seed(0)

# Model and data type.
# For bfloat16, the following line should be added to the test_forge_vs_torch function:
# In file forge/forge/op/eval/forge/eltwise_unary.py:418 should be replaced with: threshold_tensor = ac.tensor(torch.zeros(shape, dtype=torch.bfloat16) + threshold)
@@ -100,8 +98,6 @@ def test_mnist_training():

@pytest.mark.push
def test_mnist_training_with_grad_accumulation():
torch.manual_seed(0)

# Config
num_epochs = 3
batch_size = 1
@@ -179,7 +175,6 @@ def test_mnist_training_with_grad_accumulation():
@pytest.mark.push
def test_forge_vs_torch_gradients(freeze_layer):
logger.disable("")
torch.manual_seed(0)
batch_size = 64

dtype = torch.float32
@@ -239,8 +234,6 @@ def test_forge_vs_torch_gradients(freeze_layer):
@pytest.mark.skip(reason="Need to be tested with bfloat16 and takes around 10 minutes to run")
@pytest.mark.push
def test_forge_vs_torch():
torch.manual_seed(0)

batch_size = 64
learning_rate = 1e-2
epochs = 10
@@ -339,8 +332,6 @@ def test_forge_vs_torch():

@pytest.mark.push
def test_loss_device():
torch.manual_seed(0)

# Config
num_epochs = 3
batch_size = 1
@@ -415,20 +406,20 @@ def test_loss_device():

@pytest.mark.push
def test_lora():
torch.manual_seed(0)

# Config
num_epochs = 3
batch_size = 64
learning_rate = 0.001
batch_size = 128
learning_rate = 0.1

# Load dataset
test_loader, train_loader = load_dataset(batch_size)

framework_model = MNISTLora(bias=False)
framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)

tt_model = forge.compile(framework_model, sample_inputs=[torch.rand(batch_size, 784)], training=True)
tt_optimizer = forge.optimizers.SGD(learning_rate=learning_rate)
tt_model = forge.compile(
framework_model, sample_inputs=[torch.rand(batch_size, 784)], optimizer=tt_optimizer, training=True
)

loss_fn = CrossEntropyLoss(name="cross_entropy_loss")

@@ -438,11 +429,10 @@ def test_lora():

logger.info("Starting training loop... (logger will be disabled)")
logger.disable("")
prev_total_loss = 1e10
for epoch_idx in range(num_epochs):
total_loss = 0
for _, (data, target) in enumerate(train_loader):
framework_optimizer.zero_grad()

# Create target tensor and leave on CPU
target = nn.functional.one_hot(target, num_classes=10).float()

@@ -451,16 +441,20 @@ def test_lora():
golden_pred = framework_model(data)
assert compare_with_golden(golden_pred, pred, pcc=0.95)

loss = tt_loss(pred, target)
total_loss += loss[0].item()
loss = tt_loss(pred, target)[0]

total_loss += loss.item()

# Run backward pass on device
tt_loss.backward()

# Adjust weights (on CPU)
framework_optimizer.step()
# Adjust weights on the device.
# NOTE: after executing the step, this will also zero the gradients.
tt_optimizer.step()

print(f"epoch: {epoch_idx} loss: {total_loss}")
assert prev_total_loss - total_loss > 1.0, "Loss should go down"
prev_total_loss = total_loss

test_loss = 0
for _, (data, target) in enumerate(test_loader):
@@ -474,8 +468,6 @@ def test_lora():

@pytest.mark.push
def test_optimizer_device():
torch.manual_seed(0)

# Config
num_epochs = 32
batch_size = 1024
@@ -537,3 +529,72 @@ def test_optimizer_device():
break

print(f"Test (total) loss: {test_loss}")


@pytest.mark.push
def test_e2e_device():
# Config
num_epochs = 5
batch_size = 1024
learning_rate = 0.1

# Load dataset
test_loader, train_loader = load_dataset(batch_size)

framework_model = MNISTLinear(bias=False)
framework_loss = torch.nn.CrossEntropyLoss()
tt_optimizer = forge.optimizers.SGD(learning_rate=learning_rate)

tt_model = forge.compile(
framework_model, sample_inputs=[torch.rand(batch_size, 784)], optimizer=tt_optimizer, training=True
)

loss_inputs = [torch.rand(batch_size, 10).requires_grad_(True), torch.rand(batch_size, 10)]
loss_inputs = to_forge_tensors(loss_inputs)
tt_loss = forge.compile(
CrossEntropyLoss(name="cross_entropy_loss"), sample_inputs=loss_inputs, training=True, attach_to=tt_model
)

logger.info("Starting training loop... (logger will be disabled)")
logger.disable("")

prev_total_loss = 1e10
for epoch_idx in range(num_epochs):
total_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):

# Create target tensor and leave on CPU.
target = nn.functional.one_hot(target, num_classes=10).float()

# Forward pass (prediction) on device.
pred = tt_model(data)[0]
golden_pred = framework_model(data)
assert compare_with_golden(golden_pred, pred, pcc=0.95)

# Execute loss (and its backward) on device.
loss = tt_loss(pred, target)[0]
total_loss += loss.item()

golden_loss = framework_loss(pred, target)
assert compare_with_golden(golden_loss, loss, rtol=1e-1)

tt_loss.backward()

# Adjust weights on the device.
# NOTE: after executing the step, this will also zero the gradients.
tt_optimizer.step()

print(f"epoch: {epoch_idx} loss: {total_loss}")

assert prev_total_loss - total_loss > 1.0, "Loss should go down"
prev_total_loss = total_loss

test_loss = 0
for batch_idx, (data, target) in enumerate(test_loader):
pred = tt_model(data)[0]
target = nn.functional.one_hot(target, num_classes=10).float()

test_loss += framework_loss(pred, target)
break

print(f"Test (total) loss: {test_loss}")
