Skip to content

Commit

Permalink
fix pageable h2d memcpy (#5301)
Browse files Browse the repository at this point in the history
ZeRO offload case

Fix the issue of pageable h2d memcpy in the step process. Now h2d memcpy uses
pinned memory.

Speed up h2d memcpy by 6x on a single GPU and 4-5x on an 8-GPU node.

cc @tjruwase

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Ubuntu <deepspeed@deepspeed-login.2d1icxc5dsxehnpuwt3ifc34ph.gvxx.internal.cloudapp.net>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
4 people authored Apr 14, 2024
1 parent f69f884 commit 7b5b066
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,16 @@ def __init__(self,
# Note that the params in single_partition_of_fp32_groups is cloned and detached
# from the origin params of the model.
if not fp16_master_weights_and_gradients:
self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().float().detach())
weights_partition = self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().float().detach()
else:
self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().half().detach())
weights_partition = self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().half().detach()

if self.cpu_offload:
weights_partition = get_accelerator().pin_memory(weights_partition)

self.single_partition_of_fp32_groups.append(weights_partition)

# Set local optimizer to have flat params of its own partition.
# After this, the local optimizer will only contain its own partition of params.
Expand Down Expand Up @@ -1862,7 +1867,8 @@ def step(self, closure=None):
# bit16_partitions[partition_id].data.copy_(fp32_partition.data)
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
bit16_partitions[partition_id].data.copy_(
fp32_partition.to(get_accelerator().current_device_name()).data)

self.timers(OPTIMIZER_STEP_TIMER).stop()
else:
Expand Down

0 comments on commit 7b5b066

Please sign in to comment.