Commit
[chore] refactor
botbw committed Oct 17, 2024
1 parent 2bcd0b6 commit 3b1d7d1
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions colossalai/utils/safetensors.py
@@ -6,6 +6,10 @@
import torch
from safetensors.torch import _TYPES

try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()}


@@ -47,3 +51,14 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
    n = len(metadata_buf)

    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors


def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
    prepared_data, tensors = prepare(state_dict)
    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

    # 8-byte little-endian header size, followed by the serialized metadata header
    f_writer.write(n.to_bytes(8, byteorder="little"))
    f_writer.write(header_bytes)

    # raw tensor payloads, appended at the writer's current offset
    for tensor in tensors:
        f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
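
For reference, a minimal sketch (not part of the commit, and assuming the header buffer produced by prepare is the usual JSON safetensors metadata) of how the layout emitted by save could be read back: the first 8 bytes hold the little-endian header size n, the next n bytes hold the header, and the raw tensor bytes follow.

import json
import struct

def read_saved_header(path: str) -> dict:
    # Read back the layout written by save(): 8 little-endian bytes giving the
    # header size, then the header buffer produced by prepare().
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(n))  # assumed JSON metadata: name -> dtype/shape/offsets
    return header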
