diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py index 1435e5c2003e1..4a98f12cf6168 100644 --- a/examples/fabric/tensor_parallel/train.py +++ b/examples/fabric/tensor_parallel/train.py @@ -1,14 +1,13 @@ import lightning as L import torch import torch.nn.functional as F +from data import RandomTokenDataset from lightning.fabric.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader -from data import RandomTokenDataset - def train(): strategy = ModelParallelStrategy( diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py index 37c620f4582f0..6a91e1242e4af 100644 --- a/examples/pytorch/tensor_parallel/train.py +++ b/examples/pytorch/tensor_parallel/train.py @@ -1,14 +1,13 @@ import lightning as L import torch import torch.nn.functional as F +from data import RandomTokenDataset from lightning.pytorch.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader -from data import RandomTokenDataset - class Llama3(L.LightningModule): def __init__(self):