diff --git a/forge/test/mlir/mnist/training/test_training.py b/forge/test/mlir/mnist/training/test_training.py
index 45ae3cc73..045a2187f 100644
--- a/forge/test/mlir/mnist/training/test_training.py
+++ b/forge/test/mlir/mnist/training/test_training.py
@@ -19,8 +19,6 @@
 
 @pytest.mark.push
 def test_mnist_training():
-    torch.manual_seed(0)
-
     # Model and data type.
     # For bfloat16, the following line should be added to the test_forge_vs_torch function:
     # In file forge/forge/op/eval/forge/eltwise_unary.py:418 should be replaced with: threshold_tensor = ac.tensor(torch.zeros(shape, dtype=torch.bfloat16) + threshold)
@@ -100,8 +98,6 @@ def test_mnist_training():
 
 @pytest.mark.push
 def test_mnist_training_with_grad_accumulation():
-    torch.manual_seed(0)
-
     # Config
     num_epochs = 3
     batch_size = 1
@@ -179,7 +175,6 @@ def test_mnist_training_with_grad_accumulation():
 @pytest.mark.push
 def test_forge_vs_torch_gradients(freeze_layer):
     logger.disable("")
-    torch.manual_seed(0)
 
     batch_size = 64
     dtype = torch.float32
@@ -239,8 +234,6 @@ def test_forge_vs_torch_gradients(freeze_layer):
 @pytest.mark.skip(reason="Need to be tested with bfloat16 and takes around 10 minutes to run")
 @pytest.mark.push
 def test_forge_vs_torch():
-    torch.manual_seed(0)
-
     batch_size = 64
     learning_rate = 1e-2
     epochs = 10
@@ -339,8 +332,6 @@ def test_forge_vs_torch():
 
 @pytest.mark.push
 def test_loss_device():
-    torch.manual_seed(0)
-
     # Config
     num_epochs = 3
     batch_size = 1
@@ -415,20 +406,20 @@ def test_loss_device():
 
 @pytest.mark.push
 def test_lora():
-    torch.manual_seed(0)
-
     # Config
     num_epochs = 3
-    batch_size = 64
-    learning_rate = 0.001
+    batch_size = 128
+    learning_rate = 0.1
 
     # Load dataset
     test_loader, train_loader = load_dataset(batch_size)
 
     framework_model = MNISTLora(bias=False)
-    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
-    tt_model = forge.compile(framework_model, sample_inputs=[torch.rand(batch_size, 784)], training=True)
+    tt_optimizer = forge.optimizers.SGD(learning_rate=learning_rate)
+    tt_model = forge.compile(
+        framework_model, sample_inputs=[torch.rand(batch_size, 784)], optimizer=tt_optimizer, training=True
+    )
 
     loss_fn = CrossEntropyLoss(name="cross_entropy_loss")
@@ -438,11 +429,10 @@ def test_lora():
     logger.info("Starting training loop... (logger will be disabled)")
     logger.disable("")
 
+    prev_total_loss = 1e10
     for epoch_idx in range(num_epochs):
         total_loss = 0
         for _, (data, target) in enumerate(train_loader):
-            framework_optimizer.zero_grad()
-
             # Create target tensor and leave on CPU
             target = nn.functional.one_hot(target, num_classes=10).float()
 
@@ -451,16 +441,20 @@ def test_lora():
             golden_pred = framework_model(data)
             assert compare_with_golden(golden_pred, pred, pcc=0.95)
 
-            loss = tt_loss(pred, target)
-            total_loss += loss[0].item()
+            loss = tt_loss(pred, target)[0]
+
+            total_loss += loss.item()
 
             # Run backward pass on device
             tt_loss.backward()
 
-            # Adjust weights (on CPU)
-            framework_optimizer.step()
+            # Adjust weights on the device.
+            # NOTE: after executing the step, this will also zero the gradients.
+            tt_optimizer.step()
 
         print(f"epoch: {epoch_idx} loss: {total_loss}")
+        assert prev_total_loss - total_loss > 1e-5, "Loss should go down"
+        prev_total_loss = total_loss
 
     test_loss = 0
     for _, (data, target) in enumerate(test_loader):
@@ -474,8 +468,6 @@
 
 @pytest.mark.push
 def test_optimizer_device():
-    torch.manual_seed(0)
-
     # Config
     num_epochs = 32
     batch_size = 1024
@@ -537,3 +529,72 @@
         break
 
     print(f"Test (total) loss: {test_loss}")
+
+
+@pytest.mark.push
+def test_e2e_device():
+    # Config
+    num_epochs = 5
+    batch_size = 1024
+    learning_rate = 0.1
+
+    # Load dataset
+    test_loader, train_loader = load_dataset(batch_size)
+
+    framework_model = MNISTLinear(bias=False)
+    framework_loss = torch.nn.CrossEntropyLoss()
+    tt_optimizer = forge.optimizers.SGD(learning_rate=learning_rate)
+
+    tt_model = forge.compile(
+        framework_model, sample_inputs=[torch.rand(batch_size, 784)], optimizer=tt_optimizer, training=True
+    )
+
+    loss_inputs = [torch.rand(batch_size, 10).requires_grad_(True), torch.rand(batch_size, 10)]
+    loss_inputs = to_forge_tensors(loss_inputs)
+    tt_loss = forge.compile(
+        CrossEntropyLoss(name="cross_entropy_loss"), sample_inputs=loss_inputs, training=True, attach_to=tt_model
+    )
+
+    logger.info("Starting training loop... (logger will be disabled)")
+    logger.disable("")
+
+    prev_total_loss = 1e10
+    for epoch_idx in range(num_epochs):
+        total_loss = 0
+        for batch_idx, (data, target) in enumerate(train_loader):
+
+            # Create target tensor and leave on CPU.
+            target = nn.functional.one_hot(target, num_classes=10).float()
+
+            # Forward pass (prediction) on device.
+            pred = tt_model(data)[0]
+            golden_pred = framework_model(data)
+            assert compare_with_golden(golden_pred, pred, pcc=0.95)
+
+            # Execute loss (and its backward) on device.
+            loss = tt_loss(pred, target)[0]
+            total_loss += loss.item()
+
+            golden_loss = framework_loss(pred, target)
+            assert compare_with_golden(golden_loss, loss, rtol=1e-1)
+
+            tt_loss.backward()
+
+            # Adjust weights on the device.
+            # NOTE: after executing the step, this will also zero the gradients.
+            tt_optimizer.step()
+
+        print(f"epoch: {epoch_idx} loss: {total_loss}")
+
+        assert prev_total_loss - total_loss > 1e-5, "Loss should go down"
+        prev_total_loss = total_loss
+
+    test_loss = 0
+    for batch_idx, (data, target) in enumerate(test_loader):
+        pred = tt_model(data)[0]
+        target = nn.functional.one_hot(target, num_classes=10).float()
+
+        test_loss += framework_loss(pred, target)
+        break
+
+    print(f"Test (total) loss: {test_loss}")
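
A condensed sketch of the on-device training flow that the updated test_lora and the new test_e2e_device exercise: compile the model with a Forge optimizer, compile the loss with attach_to=tt_model, then run forward, loss, backward, and the optimizer step on device. This is illustrative only; the helpers (MNISTLinear, load_dataset, CrossEntropyLoss, to_forge_tensors) are assumed to come from the test module's existing imports, and the comments restate behavior shown in the patch rather than independently verified API guarantees.

import torch
import torch.nn as nn
import forge

batch_size, learning_rate = 1024, 0.1
test_loader, train_loader = load_dataset(batch_size)  # helper from the test module (assumed import)

# Compile the model for training with the optimizer attached, as in the patch.
framework_model = MNISTLinear(bias=False)
tt_optimizer = forge.optimizers.SGD(learning_rate=learning_rate)
tt_model = forge.compile(
    framework_model, sample_inputs=[torch.rand(batch_size, 784)], optimizer=tt_optimizer, training=True
)

# Compile the loss module and attach it to the compiled model.
loss_inputs = to_forge_tensors([torch.rand(batch_size, 10).requires_grad_(True), torch.rand(batch_size, 10)])
tt_loss = forge.compile(
    CrossEntropyLoss(name="cross_entropy_loss"), sample_inputs=loss_inputs, training=True, attach_to=tt_model
)

for data, target in train_loader:
    target = nn.functional.one_hot(target, num_classes=10).float()  # target stays on CPU
    pred = tt_model(data)[0]         # forward pass on device
    loss = tt_loss(pred, target)[0]  # loss computed on device
    tt_loss.backward()               # backward pass on device
    tt_optimizer.step()              # weight update on device; per the patch, this also zeroes the gradients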