Hello, I encountered an error when calling flop_count_table() in my distributed training code.
The error message is below. I checked the inputs to the all_gather() call and didn't find anything unusual.
File "/xxx/anaconda3/envs/torch13/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2275, in all_gather
work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: unsupported input list type: Tensor[]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 628698) of binary: /xxx/anaconda3/envs/torch13/bin/python
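For reference, this is the kind of standalone check I mean (a minimal sketch, launched the same way under torch.distributed.run): the inputs are a plain list of Tensors plus a Tensor, which matches the documented all_gather() signature, and nothing about them looks unusual to me.

import torch
import torch.distributed as dist

# Minimal standalone version of the failing collective, outside
# flop_count_table(): a list of Tensors to receive plus the local
# Tensor to send, exactly the documented all_gather() signature.
dist.init_process_group(backend='nccl')
t = torch.randn(100, 1).cuda()
gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, t, async_op=False)
dist.destroy_process_group()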
Here's a brief script that reproduces the error when run with python -m torch.distributed.run --nproc_per_node=1 --master_port 10603 try.py:
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        y = self.fc(x)
        # The collective below is what fails under FlopCountAnalysis.
        concat_all_gather(y)
        return y.sum()


@torch.no_grad()
def concat_all_gather(tensor):
    # Gather the tensor from every rank and concatenate along dim 0.
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output


torch.distributed.init_process_group(backend='nccl')
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

model = SimpleModel().cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

# FLOP counting traces forward(), which hits the all_gather above.
flop = FlopCountAnalysis(model.module, torch.randn(100, 10).cuda())
print(flop_count_table(flop, max_depth=7, show_param_shapes=True))

torch.distributed.destroy_process_group()
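A possible workaround (a minimal sketch, assuming fvcore's FlopCountAnalysis traces the model through TorchScript so that torch.jit.is_tracing() returns True inside forward() during analysis; I haven't verified that assumption) is to skip the collective while tracing:

import torch

@torch.no_grad()
def concat_all_gather(tensor):
    # Assumption: torch.jit.is_tracing() is True while fvcore's
    # JIT-based analysis runs, so the c10d call is never reached.
    if torch.jit.is_tracing():
        # Fake the gathered shape so the rest of the trace still
        # sees a world_size-times-larger tensor along dim 0.
        return torch.cat([tensor] * torch.distributed.get_world_size(), dim=0)
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)

This would keep flop_count_table() away from the failing allgather path; as far as I know all_gather has no FLOP handler in fvcore anyway, so the reported totals shouldn't change.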
Additionally, my environment is: Python 3.9.18, CUDA 11.7, fvcore==0.1.5.post20221221, torch 1.13.
Another confusing thing: in a Python 3.8.18 / CUDA 11.4 / torch 1.10 environment, the same code does not raise this error.