Multi-GPU Inference #9259
Replies: 7 comments 20 replies
-
@ricardorei Have you solved this problem? I find that
-
this doesn't work for you?

```python
trainer = Trainer(..., strategy='ddp')
model = ...
preds = trainer.predict(model, predict_dataloader)
```
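For completeness, here is a minimal self-contained sketch of that suggestion; the toy model and random data are only placeholders for your own LightningModule and DataLoader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Placeholder model; substitute your own LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        (x,) = batch
        return self.layer(x)


if __name__ == "__main__":
    predict_dataloader = DataLoader(TensorDataset(torch.randn(128, 8)), batch_size=16)

    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
    preds = trainer.predict(ToyModel(), predict_dataloader)
    # Note: under DDP each process runs on its own shard of the data, so
    # `preds` only contains the predictions made by the local rank.
```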
-
@rohitgr7
-
Alright, it took me some time to figure out the best way to do this, but here is my solution using DDP:

```python
import logging
import os
import shutil
import tempfile

import torch
from pytorch_lightning.callbacks import BasePredictionWriter

from .utils import Prediction, flatten_metadata, restore_list_order

logger = logging.getLogger(__name__)


class CustomWriter(BasePredictionWriter):
    """PyTorch Lightning callback that saves predictions and the corresponding batch
    indices in a temporary folder when using multi-GPU inference.

    Args:
        write_interval (str): When to perform write operations. Defaults to 'epoch'.
    """

    def __init__(self, write_interval="epoch") -> None:
        super().__init__(write_interval)

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Saves predictions after running inference on all samples."""
        # We need to save predictions in the most secure manner possible to avoid
        # multiple users and processes writing to the same folder.
        # For that we will create a tmp folder that will be shared only across
        # the DDP processes that were created.
        if trainer.is_global_zero:
            output_dir = [tempfile.mkdtemp()]
            logger.info(
                "Created temporary folder to store predictions: {}.".format(output_dir[0])
            )
        else:
            output_dir = [None]

        torch.distributed.broadcast_object_list(output_dir)
        # Make sure every process received the output_dir from RANK=0
        torch.distributed.barrier()

        # Now that we have a single output_dir shared across processes we can save
        # predictions along with their indices.
        self.output_dir = output_dir[0]

        # This will create N (num processes) files in `output_dir`, each containing
        # the predictions of its respective rank.
        torch.save(
            predictions, os.path.join(self.output_dir, f"pred_{trainer.global_rank}.pt")
        )
        # Optionally, you can also save `batch_indices` to get the information about
        # the data index from your prediction data.
        torch.save(
            batch_indices,
            os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"),
        )

    def gather_all_predictions(self):
        """Reads all saved predictions from self.output_dir into one single
        Prediction object, respecting the original order of the samples.
        """
        files = sorted(os.listdir(self.output_dir))
        pred = flatten_predictions(
            [torch.load(os.path.join(self.output_dir, f))[0] for f in files if "pred" in f]
        )
        indices = flatten_predictions(
            [torch.load(os.path.join(self.output_dir, f))[0] for f in files if "batch_indices" in f]
        )
        # TODO: building the final output from `pred` and `indices` depends on your application
        return pred, indices

    def cleanup(self):
        """Cleans temporary files."""
        logger.info("Cleanup temporary folder: {}.".format(self.output_dir))
        shutil.rmtree(self.output_dir)
```
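A sketch of how the callback above might be wired up; `model` and `predict_dataloader` are placeholders for your own objects.

```python
from pytorch_lightning import Trainer

pred_writer = CustomWriter(write_interval="epoch")
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", callbacks=[pred_writer])

# return_predictions=False avoids keeping every prediction in memory on each
# rank; the callback persists them to the shared temporary folder instead.
trainer.predict(model, predict_dataloader, return_predictions=False)

# Only rank zero needs to read the per-rank files back and clean up.
if trainer.is_global_zero:
    pred, indices = pred_writer.gather_all_predictions()
    pred_writer.cleanup()
```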
-
This code is honestly not well documented. 🥁 After experimenting for many hours, this worked for me 🥁:
The problem is that if you don't call gather on all nodes, it will hang waiting for the other nodes to respond. Native PyTorch has comparable functions for
The annoying thing you will find is that this function is called after the model returns predictions, i.e.:
and if you coalesce the results returned by this line with strategy
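If the missing reference above is to the collective ops, a minimal sketch using torch.distributed.all_gather_object could look like the following; the helper name and the way predictions are flattened are illustrative only. Because it is a collective, every rank must call it, which is exactly why skipping the call on some nodes makes the others hang.

```python
import torch.distributed as dist


def gather_predictions(local_predictions, world_size):
    """Hypothetical helper: gather each rank's list of predictions on all ranks."""
    gathered = [None] * world_size
    # Collective call: every process in the group must reach this line,
    # otherwise the remaining ranks block waiting for it.
    dist.all_gather_object(gathered, local_predictions)
    # Every rank now holds the predictions from all ranks.
    return [p for rank_preds in gathered for p in rank_preds]
```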
-
Here is an updated gist of how to do this: https://gist.github.com/will-thompson-k/f6201b68c428d0344a6affa6d53bc91b
-
Is multi-GPU inference currently supported for trainers with strategy "deepspeed_stage_3"?
-
Hi all!
What is the best way to perform inference (predict) using multi-GPU?
ATM in our framework we are relying on DP, which is extremely slow, and when I switch to DDP it basically splits the data loader into several data loaders and produces several "independent" system outputs. I would like something like DDP where, in the end, I could call a "merge" function to gather all the predictions that were performed by the different processes.
Am I missing something? There is probably a way to do this already...
Here is our predict code: https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L395
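For what it's worth, one way to get that kind of "merge" is to gather inside predict_step with LightningModule.all_gather. This is only a sketch, assuming each prediction is a fixed-size tensor and using a toy stand-in model rather than the COMET one.

```python
import torch
import pytorch_lightning as pl


class ToyRegressor(pl.LightningModule):
    """Hypothetical stand-in model; the gather pattern is the point here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        local_pred = self.layer(batch)           # (batch_size, 1) on this rank
        # Under DDP, all_gather adds a leading world_size dimension, so every
        # rank ends up with the predictions computed by every other rank.
        gathered = self.all_gather(local_pred)   # (world_size, batch_size, 1)
        return gathered.flatten(0, 1)            # (world_size * batch_size, 1)
```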