From 3b1d7d1ae88683a27e45fba984050c970370b0e6 Mon Sep 17 00:00:00 2001
From: botbw
Date: Mon, 14 Oct 2024 09:41:25 +0000
Subject: [PATCH] [chore] refactor

---
 colossalai/utils/safetensors.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 4e295cdfc111..9aa3558d9926 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -6,6 +6,10 @@
 import torch
 from safetensors.torch import _TYPES
 
+try:
+    from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 
 _TYPES_INV = {v: k for k, v in _TYPES.items()}
 
@@ -47,3 +51,14 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
     n = len(metadata_buf)
 
     return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
+
+
+def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    prepared_data, tensors = prepare(state_dict)
+    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+
+    f_writer.write(n.to_bytes(8, byteorder="little"))
+    f_writer.write(header_bytes)
+
+    for tensor in tensors:
+        f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)