# base.py (forked from learning-at-home/hivemind)
import dataclasses
import os
import warnings
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, Optional

import numpy as np
import torch

from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import TensorDescriptor

# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
warnings.filterwarnings("ignore", message="The given NumPy array is not writable", category=UserWarning)

USE_LEGACY_BFLOAT16 = bool(int(os.environ.get("USE_LEGACY_BFLOAT16", 1)))
Key = Any


class TensorRole(Enum):
    ACTIVATION = auto()
    PARAMETER = auto()
    GRADIENT = auto()
    OPTIMIZER = auto()
    UNSPECIFIED = auto()


@dataclasses.dataclass(frozen=True)
class CompressionInfo:
    """Auxiliary data structure that contains information about the tensor that determines how it is compressed"""

    key: Key  # name or index of the tensor, e.g. from named parameters, an optimizer state dict, or an i/o structure
    descriptor: TensorDescriptor  # data structure that defines the tensor's shape, dtype, layout and device
    role: TensorRole = TensorRole.UNSPECIFIED  # which role the tensor plays with respect to the model
    part_index: int = 0  # if the tensor is sliced into parts, the index of this part within the full tensor
    part_size: Optional[int] = None  # if the tensor is sliced into parts, the _maximum_ number of values per part

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor, key: Key = None, descriptor: TensorDescriptor = None, **kwargs):
        return cls(key, descriptor or TensorDescriptor.from_tensor(tensor), **kwargs)

    def get_part(self, part_index: int, part_size: Optional[int]):
        return CompressionInfo(self.key, self.descriptor, self.role, part_index=part_index, part_size=part_size)


class CompressionBase(ABC):
    """A base class that applies a compression algorithm to a pytorch tensor"""

    compression_type: runtime_pb2.CompressionType

    @abstractmethod
    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
        """
        Applies the compression algorithm to a tensor based on its meta-parameters

        :param tensor: a pytorch tensor to compress; depending on the application, it is a full tensor or a part
        :param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
        :param allow_inplace: if True, compression can (but does not have to) modify the tensor in-place for efficiency
        :returns: a protobuf message that encodes the tensor
        """
        ...

    @abstractmethod
    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
        """Create a pytorch tensor from the serialized outputs of .compress"""
        ...

    @abstractmethod
    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
        """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
        ...

    def __repr__(self):
        return f"hivemind.{self.__class__.__name__}()"


class NoCompression(CompressionBase):
    """A dummy compression strategy that preserves the original tensor as is."""

    compression_type = runtime_pb2.CompressionType.NONE

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
        requires_grad = tensor.requires_grad
        tensor = tensor.detach()
        shape = tensor.shape
        dtype_name = str(tensor.dtype).replace("torch.", "")
        raw_data = tensor
        if tensor.dtype == torch.bfloat16:
            if USE_LEGACY_BFLOAT16:  # legacy mode: convert to fp32
                raw_data = tensor.to(torch.float32)
            else:  # efficient mode: send the bfloat16 data directly
                # reinterpret_cast to an arbitrary 2-byte dtype supported by numpy
                raw_data = tensor.view(torch.int16)
        return runtime_pb2.Tensor(
            compression=self.compression_type,
            buffer=raw_data.numpy().tobytes(),
            size=shape,
            dtype=dtype_name,
            requires_grad=requires_grad,
        )

    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
        shape = torch.Size(serialized_tensor.size)
        if serialized_tensor.dtype == "bfloat16":
            numel = shape.numel()
            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:
                # legacy mode: the sender upcast bfloat16 to float32 (4 bytes per value)
                array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
                tensor = torch.as_tensor(array, dtype=torch.bfloat16)
            else:
                array = np.frombuffer(serialized_tensor.buffer, dtype=np.int16)
                # reinterpret_cast from an arbitrary 2-byte dtype supported by numpy
                tensor = torch.as_tensor(array).view(torch.bfloat16)
        else:
            array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
            tensor = torch.as_tensor(array)
        return tensor.reshape(shape)

    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
        return 1.0
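

if __name__ == "__main__":
    # Minimal usage sketch: round-trip a tensor through NoCompression and verify
    # that it is reconstructed exactly. The key "demo" and the role are arbitrary.
    compression = NoCompression()
    original = torch.randn(4, 8)
    info = CompressionInfo.from_tensor(original, key="demo", role=TensorRole.ACTIVATION)
    serialized = compression.compress(original, info)
    restored = compression.extract(serialized)
    assert torch.equal(restored, original)
    print(compression, "ratio:", compression.estimate_compression_ratio(info))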