Cugraph Examples Updates #9953

Open · wants to merge 16 commits into master
Changes from 4 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Consolidated cuGraph examples for `ogbn-arxiv`, `ogbn-products`, and `ogbn-papers100M` into `ogbn_train_cugraph.py` and `ogbn_train_cugraph_multigpu.py` ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
387 changes: 387 additions & 0 deletions examples/multi_gpu/ogbn_train_cugraph_multigpu.py
@@ -0,0 +1,387 @@
"""Single-node, multi-GPU example."""

import argparse
import os
import os.path as osp
import tempfile
import time

import cupy
import rmm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from cugraph.gnn import (
cugraph_comms_create_unique_id,
cugraph_comms_init,
cugraph_comms_shutdown,
)
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel

import torch_geometric
from torch_geometric.utils import to_undirected

# Allow computation on objects that are larger than GPU memory
# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory
os.environ['CUDF_SPILL'] = '1'

# Ensure that a CUDA context is not created on import of RAPIDS.
# This allows PyTorch to create the context instead.
os.environ['RAPIDS_NO_INITIALIZE'] = '1'


def arg_parse():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
'--dataset',
type=str,
default='ogbn-papers100M',
choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'],
help='Dataset name.',
)
parser.add_argument(
'--dataset_dir',
type=str,
default='/workspace/data',
help='Root directory of dataset.',
)
parser.add_argument(
"--dataset_subdir",
type=str,
default="ogb-papers100M",
help="directory of dataset.",
)
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--wd', type=float, default=0.000)
parser.add_argument('-e', '--epochs', type=int, default=10)
parser.add_argument('-b', '--batch_size', type=int, default=1024)
parser.add_argument('--fan_out', type=int, default=10)
parser.add_argument('--eval_steps', type=int, default=1000)
parser.add_argument('--warmup_steps', type=int, default=20)
parser.add_argument(
'--use_directed_graph',
action='store_true',
        help='Whether or not to use a directed graph.',
)
parser.add_argument(
"--model",
type=str,
default='GCN',
choices=['SAGE', 'GAT', 'GCN'],
help="Model used for training, default GraphSAGE",
)
parser.add_argument(
"--num_gat_conv_heads",
type=int,
default=4,
help="If using GATConv, number of attention heads to use",
)
parser.add_argument('--tempdir_root', type=str, default=None)
parser.add_argument(
'--num_devices',
type=int,
default=-1,
help='How many GPUs to use. Defaults to all available GPUs',
)
parser.add_argument(
'--verbose',
action='store_true',
        help='Whether or not to generate a statistical report.',
)
args = parser.parse_args()

return args


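# Compute classification accuracy over all batches from the given loader,
# scoring only the seed nodes (the first `batch_size` rows) of each batch.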
def evaluate(rank, loader, model):
    # Switch to eval mode so any dropout/normalization layers behave
    # deterministically during evaluation.
    model.eval()
    with torch.no_grad():
total_correct = total_examples = 0
for i, batch in enumerate(loader):
batch = batch.to(rank)
batch_size = batch.batch_size

batch.y = batch.y.to(torch.long)
out = model(batch.x, batch.edge_index)[:batch_size]

pred = out.argmax(dim=-1)
y = batch.y[:batch_size].view(-1).to(torch.long)

total_correct += int((pred == y).sum())
total_examples += y.size(0)

acc = total_correct / total_examples
return acc


def init_pytorch_worker(rank, world_size, cugraph_id):

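    # Create a per-device RMM pool with managed (unified) memory so that
    # allocations can exceed physical GPU memory.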
rmm.reinitialize(
devices=rank,
managed_memory=True,
pool_allocator=True,
)

cupy.cuda.Device(rank).use()
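    # Route CuPy allocations through the same RMM pool.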
from rmm.allocators.cupy import rmm_cupy_allocator

cupy.cuda.set_allocator(rmm_cupy_allocator)

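    # Enable cuDF spilling to host memory (see CUDF_SPILL above).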
from cugraph.testing.mg_utils import enable_spilling

enable_spilling()

torch.cuda.set_device(rank)

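    # Initialize the NCCL process group used by DDP, then the cuGraph
    # communications layer on top of the same ranks.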
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)

cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id,
device=rank)


def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
num_classes, wall_clock_start, tempdir=None):

epochs = args.epochs
batch_size = args.batch_size
fan_out = args.fan_out
num_layers = args.num_layers

init_pytorch_worker(
rank,
world_size,
cugraph_id,
)

model = model.to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.wd)

kwargs = dict(
num_neighbors=[fan_out] * num_layers,
batch_size=batch_size,
)
from cugraph_pyg.data import GraphStore, TensorDictFeatureStore
from cugraph_pyg.loader import NeighborLoader

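    # Register this rank's shard of the edge list with the multi-GPU
    # GraphStore; edges are split evenly across ranks along dim 1.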
graph_store = GraphStore(is_multi_gpu=True)
ixr = torch.tensor_split(data.edge_index, world_size, dim=1)[rank]
graph_store[dict(
edge_type=('node', 'rel', 'node'),
layout='coo',
is_sorted=False,
size=(data.num_nodes, data.num_nodes),
)] = ixr

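    # Node features and labels are served to the loader through the
    # feature store.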
feature_store = TensorDictFeatureStore()
feature_store['node', 'x'] = data.x
feature_store['node', 'y'] = data.y

dist.barrier()

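    # Each rank samples from its own equal split of the global train/val/test
    # indices and writes sampled batches to a per-rank directory.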
ix_train = torch.tensor_split(split_idx['train'], world_size)[rank].cuda()
train_path = osp.join(tempdir, f'train_{rank}')
os.mkdir(train_path)
train_loader = NeighborLoader(
(feature_store, graph_store),
input_nodes=ix_train,
directory=train_path,
shuffle=True,
drop_last=True,
**kwargs,
)

ix_val = torch.tensor_split(split_idx['valid'], world_size)[rank].cuda()
val_path = osp.join(tempdir, f'val_{rank}')
os.mkdir(val_path)
val_loader = NeighborLoader(
(feature_store, graph_store),
input_nodes=ix_val,
directory=val_path,
drop_last=True,
**kwargs,
)

ix_test = torch.tensor_split(split_idx['test'], world_size)[rank].cuda()
test_path = osp.join(tempdir, f'test_{rank}')
os.mkdir(test_path)
test_loader = NeighborLoader(
(feature_store, graph_store),
input_nodes=ix_test,
directory=test_path,
drop_last=True,
local_seeds_per_call=80000,
**kwargs,
)

dist.barrier()

    warmup_steps = args.warmup_steps
dist.barrier()
torch.cuda.synchronize()

if rank == 0:
prep_time = time.perf_counter() - wall_clock_start
print("Total time before training begins (prep_time) =", prep_time,
"seconds")
print("Beginning training...")

test_accs = []
val_accs = []
times = []
train_times = []
inference_times = []
best_val = best_test = 0.
start = time.time()
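    # Main training loop: one pass over the sharded training seeds per epoch,
    # followed by evaluation on all three splits.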
for epoch in range(1, epochs + 1):
        train_start = time.time()
        model.train()  # re-enable training mode after evaluation
        total_loss = 0
i = 0
for i, batch in enumerate(train_loader):
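            # Synchronize once after the warmup iterations so that startup
            # overhead does not distort the epoch timings.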
if i == warmup_steps:
torch.cuda.synchronize()
batch = batch.to(rank)
batch_size = batch.batch_size
batch.y = batch.y.to(torch.long)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size])
loss.backward()
optimizer.step()
            total_loss += float(loss)  # detach before accumulating
        train_end = time.time()
        train_times.append(train_end - train_start)
        total_loss /= i + 1  # average over the number of batches
dist.barrier()
torch.cuda.synchronize()

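        # Barriers keep ranks in lockstep between evaluations, since the
        # multi-GPU sampler performs collective operations across ranks.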
inference_start = time.time()
train_acc = evaluate(rank, train_loader, model)
dist.barrier()
val_acc = evaluate(rank, val_loader, model)
dist.barrier()
test_acc = evaluate(rank, test_loader, model)
dist.barrier()

inference_times.append(time.time() - inference_start)
test_accs.append(test_acc)
val_accs.append(val_acc)
if rank == 0:
print(f'Epoch {epoch:02d}, Loss: {total_loss:.4f}, Approx. Train:'
f' {train_acc:.4f} Time: {train_end - train_start:.4f}s')
print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
f'Test: {test_acc:.4f}')

times.append(time.time() - train_start)
if val_acc > best_val:
best_val = val_acc
if test_acc > best_test:
best_test = test_acc

print(f"Total time used for rank: {rank:02d} is {time.time()-start:.4f}")
if rank == 0:
test_acc = torch.tensor(test_accs)
val_acc = torch.tensor(val_accs)
print('============================')
print("Average Epoch Time on training: {:.4f}".format(
torch.tensor(train_times).mean()))
print("Average Epoch Time on inference: {:.4f}".format(
torch.tensor(inference_times).mean()))
print(f"Average Epoch Time: {torch.tensor(times).mean():.4f}")
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')
print(f'Final Validation: {val_acc.mean():.4f} ± {val_acc.std():.4f}')
print(f"Best validation accuracy: {best_val:.4f}")
print(f"Best testing accuracy: {best_test:.4f}")

if rank == 0:
print("Testing...")
final_test_acc = evaluate(rank, test_loader, model)
dist.barrier()
if rank == 0:
print(f'Test Accuracy: {final_test_acc:.4f} for rank: {rank:02d}')
if rank == 0:
total_time = time.perf_counter() - wall_clock_start
print(f"Total Training Runtime: {total_time - prep_time}s")
print(f"Total Program Runtime: {total_time}s")

cugraph_comms_shutdown()
dist.destroy_process_group()


if __name__ == '__main__':

args = arg_parse()

wall_clock_start = time.perf_counter()

root = osp.join(args.dataset_dir, args.dataset_subdir)
dataset = PygNodePropPredDataset(name=args.dataset, root=root)
split_idx = dataset.get_idx_split()
data = dataset[0]
if not args.use_directed_graph:
data.edge_index = to_undirected(data.edge_index, reduce="mean")
data.y = data.y.reshape(-1)

if args.model == "GAT":
print(f"Training {args.dataset} with GAT model.")
model = torch_geometric.nn.models.GAT(dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
heads=args.num_gat_conv_heads)
elif args.model == "GCN":
print(f"Training {args.dataset} with GCN model.")
model = torch_geometric.nn.models.GCN(
dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
)
elif args.model == "SAGE":
print(f"Training {args.dataset} with GraphSAGE model.")
model = torch_geometric.nn.models.GraphSAGE(
dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

print("Data =", data)
if args.num_devices < 1:
world_size = torch.cuda.device_count()
elif args.num_devices <= torch.cuda.device_count():
world_size = args.num_devices
else:
world_size = torch.cuda.device_count()
    print(f"Let's use {world_size} GPUs!")

# Create the uid needed for cuGraph comms
cugraph_id = cugraph_comms_create_unique_id()

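    # Spawn one training process per GPU; sampled batches live under a shared
    # temporary directory that is removed on exit.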
with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir:
if world_size > 1:
mp.spawn(
run_train,
args=(args, data, world_size, cugraph_id, model, split_idx,
dataset.num_classes, wall_clock_start, tempdir),
nprocs=world_size,
join=True,
)
else:
run_train(0, args, data, world_size, cugraph_id, model, split_idx,
dataset.num_classes, wall_clock_start, tempdir)

total_time = round(time.perf_counter() - wall_clock_start, 2)
print("Total Program Runtime (total_time) =", total_time, "seconds")