diff --git a/mediapipe/tasks/tools/BUILD b/mediapipe/model_maker/python/llm/BUILD
similarity index 100%
rename from mediapipe/tasks/tools/BUILD
rename to mediapipe/model_maker/python/llm/BUILD
diff --git a/mediapipe/tasks/tools/converter_base.py b/mediapipe/model_maker/python/llm/converter_base.py
similarity index 100%
rename from mediapipe/tasks/tools/converter_base.py
rename to mediapipe/model_maker/python/llm/converter_base.py
diff --git a/mediapipe/tasks/tools/quantization_util.py b/mediapipe/model_maker/python/llm/quantization_util.py
similarity index 100%
rename from mediapipe/tasks/tools/quantization_util.py
rename to mediapipe/model_maker/python/llm/quantization_util.py
diff --git a/mediapipe/tasks/tools/quantization_util_test.py b/mediapipe/model_maker/python/llm/quantization_util_test.py
similarity index 99%
rename from mediapipe/tasks/tools/quantization_util_test.py
rename to mediapipe/model_maker/python/llm/quantization_util_test.py
index 81b5425559..c858dc1b77 100644
--- a/mediapipe/tasks/tools/quantization_util_test.py
+++ b/mediapipe/model_maker/python/llm/quantization_util_test.py
@@ -19,7 +19,7 @@
 from jax import numpy as jnp
 import numpy as np
 
-from mediapipe.tasks.tools import quantization_util
+from mediapipe.model_maker.python.llm import quantization_util
 
 _dtype = lambda x: getattr(x, 'dtype', None) or np.asarray(x).dtype
diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD
index 50a3b46f22..b05e748b0a 100644
--- a/mediapipe/tasks/testdata/text/BUILD
+++ b/mediapipe/tasks/testdata/text/BUILD
@@ -20,6 +20,7 @@ load(
 package(
     default_visibility = [
         "//mediapipe/calculators/tensor:__subpackages__",
+        "//mediapipe/model_maker:__subpackages__",
         "//mediapipe/tasks:__subpackages__",
     ],
     licenses = ["notice"],  # Apache 2.0
diff --git a/mediapipe/tasks/tools/pytorch_converter.py b/mediapipe/tasks/tools/pytorch_converter.py
deleted file mode 100644
index 6162091120..0000000000
--- a/mediapipe/tasks/tools/pytorch_converter.py
+++ /dev/null
@@ -1,259 +0,0 @@
-"""CkptLoader implementation for loading the Pytorch file."""
-
-import enum
-import os
-from typing import List, Optional
-
-import numpy as np
-import torch
-
-from mediapipe.tasks.tools import converter_base
-
-
-class LayerType(enum.Enum):
-  """Enum for layer type."""
-
-  NONE = 0
-  ATTENTION = 1  # Layer is part of the attention module.
-  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
-  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
-  LAYER_NORM = (
-      4  # Layer is layer normalization before and after attention layer.
-  )
-
-  @classmethod
-  def get_layer_type(cls, layer_name: str):
-    """Gets the layer type of the given layer name."""
-    ffn_layers = [
-        "mlp",
-    ]
-    attn_layers = [
-        "self_attention",
-    ]
-    emb_layers = [
-        "word_embeddings",
-        "lm_head",
-    ]
-    layer_norms = [
-        "input_layernorm",
-        "post_attention_layernorm",
-        "ln_f",
-    ]
-    if any(sub_name in layer_name for sub_name in attn_layers):
-      return LayerType.ATTENTION
-    if any(sub_name in layer_name for sub_name in ffn_layers):
-      return LayerType.FEEDFORWARD
-    if any(sub_name in layer_name for sub_name in emb_layers):
-      return LayerType.EMBEDDING
-    if any(sub_name in layer_name for sub_name in layer_norms):
-      return LayerType.LAYER_NORM
-    else:
-      return LayerType.NONE
-
-
-class FalconMapper(converter_base.LayerActionMapperBase):
-  """LayerActionMapper for handling the Falcon-rw-1b model."""
-
-  # we don't quantize embedding, final MLP and layer norm for falcon model.
-  NON_QUANTIZED_LAYERS = [
-      "transformer.word_embeddings.weight",
-      "transformer.ln_f",
-      "lm_head",
-      "input_layernorm",
-      "post_attention_layernorm",
-  ]
-
-  def map_to_actions(
-      self, layer_name: str
-  ) -> Optional[converter_base.QuantizationAction]:
-    """Map the given layer name to actions."""
-    quantize_axis = None
-    quantize_bits = None
-    if all(name not in layer_name for name in self.NON_QUANTIZED_LAYERS) and (
-        layer_name.endswith(".weight")
-    ):
-      layer_type = LayerType.get_layer_type(layer_name)
-      quantize_axis = [0]
-      if layer_type == LayerType.FEEDFORWARD:
-        quantize_bits = self._feedforward_quant_bits
-      elif layer_type == LayerType.ATTENTION:
-        quantize_bits = self._attention_quant_bits
-      elif layer_type == LayerType.EMBEDDING:
-        quantize_bits = self._embedding_quant_bits
-
-    return converter_base.QuantizationAction(
-        tensor_name=layer_name,
-        target_name=layer_name,
-        quantize_axis=quantize_axis,
-        quantize_bits=quantize_bits,
-        pack_dim=0,
-    )
-
-  def update_target_name(self, target_name: str) -> str:
-    """Updates the target name to match the tensor name convention."""
-    layer_type = LayerType.get_layer_type(target_name)
-
-    target_name = target_name.replace(
-        "transformer.h.", "params.lm.transformer.x_layers_"
-    )
-
-    if layer_type == LayerType.FEEDFORWARD:
-      target_name = target_name.replace(".weight", ".linear.w")
-      target_name = target_name.replace(".bias", ".bias.b")
-      target_name = target_name.replace(
-          "mlp.dense_h_to_4h", "ff_layer.ffn_layer1"
-      )
-      target_name = target_name.replace(
-          "mlp.dense_4h_to_h", "ff_layer.ffn_layer2"
-      )
-    elif layer_type == LayerType.ATTENTION:
-      target_name = target_name.replace("dense", "post")
-      target_name = target_name.replace(".weight", ".linear.w")
-      target_name = target_name.replace(".bias", ".bias.b")
-    elif layer_type == LayerType.EMBEDDING:
-      target_name = target_name.replace(
-          "transformer.word_embeddings", "params.lm.token_embedding"
-      )
-      target_name = target_name.replace(
-          "lm_head", "params.lm.softmax.logits_ffn"
-      )
-      target_name = target_name.replace(".weight", ".w")
-    elif layer_type == LayerType.LAYER_NORM:
-      target_name = target_name.replace("input_layernorm", "pre_layer_norm")
-      target_name = target_name.replace(
-          "pre_layer_norm.weight", "pre_layer_norm.scale"
-      )
-      target_name = target_name.replace(
-          "post_attention_layernorm", "post_layer_norm"
-      )
-      target_name = target_name.replace(
-          "post_layer_norm.weight", "post_layer_norm.scale"
-      )
-      target_name = target_name.replace(
-          "transformer.ln_f.weight", "params.lm.final_ln.scale"
-      )
-      target_name = target_name.replace(
-          "transformer.ln_f.bias", "params.lm.final_ln.bias"
-      )
-
-    return target_name
-
-
-class PytorchCkptLoader(converter_base.CkptLoaderBase):
-  """CkptLoader implementation for loading the Pytorch model."""
-
-  def __init__(
-      self,
-      ckpt_path: str,
-      is_symmetric: bool,
-      attention_quant_bits: int,
-      feedforward_quant_bits: int,
-      embedding_quant_bits: int,
-      special_model: str,
-  ):
-    """Initializes the loader.
-
-    Args:
-      ckpt_path: The filepath to the safetensors file.
-      is_symmetric: Whether to apply symmetric or asymmetric quantization.
-      attention_quant_bits: An integer that specify the target quantization bits
-        (support 8 or 4) for the attention layers.
-      feedforward_quant_bits: An integer that specify the target quantization
-        bits (support 8 or 4) for the feedforward layers in each Transformer
-        blocks.
-      embedding_quant_bits: An integer that specify the target quantization bits
-        (support 8 or 4) for the embedding (and the final projection) layers.
-      special_model: A string that indicates which input model is and whether
-        any special treatment is needed.
-    """
-    super().__init__(
-        ckpt_path,
-        is_symmetric,
-        attention_quant_bits,
-        feedforward_quant_bits,
-        embedding_quant_bits,
-    )
-
-    self._special_model = special_model
-    if special_model in ["FALCON_RW_1B"]:
-      self.mapper = FalconMapper(
-          is_symmetric,
-          attention_quant_bits,
-          feedforward_quant_bits,
-      )
-    else:
-      raise ValueError(f"Unknown special model: {special_model}")
-
-    self._ckpt_path = ckpt_path
-    if not os.path.exists(self._ckpt_path):
-      raise ValueError(f"{self._ckpt_path} does not exists.")
-    self._model = torch.load(self._ckpt_path, map_location=torch.device("cpu"))
-
-  def load_to_actions(self):
-    tensor_names = self._model.keys()
-    actions = []
-    for tensor_name in tensor_names:
-      tensor_value = (
-          self._model[tensor_name]
-          .to(torch.float32)
-          .t()
-          .contiguous()
-          .detach()
-          .cpu()
-          .numpy()
-      )
-      if (
-          isinstance(self.mapper, FalconMapper)
-          and "query_key_value" in tensor_name
-      ):
-        qkv_tensors = self._decompose_falcon_qkv(tensor_value)
-        for tensor, qkv_name in zip(qkv_tensors, ["q", "k", "v"]):
-          decomposed_name = tensor_name.replace("query_key_value", qkv_name)
-          action = self.mapper.map_to_actions(decomposed_name)
-          action.tensor_value = tensor
-          action.target_name = self.mapper.update_target_name(decomposed_name)
-          actions.append(action)
-      else:
-        action = self.mapper.map_to_actions(tensor_name)
-        if action is None:
-          continue
-        action.tensor_value = tensor_value
-        action.target_name = self.mapper.update_target_name(tensor_name)
-        actions.append(action)
-    return actions
-
-  def _decompose_falcon_qkv(self, tensor_value: np.ndarray) -> List[np.ndarray]:
-    """Decomposes combined qkv tensor used in falcon model into separate q, k and v tensors."""
-    chunk_size = 64
-    hidden_size = 2048
-
-    tensor_value = tensor_value.transpose()
-
-    q_tensor = np.zeros(
-        (hidden_size,)
-        + ((hidden_size,) if len(tensor_value.shape) == 2 else ()),
-        dtype=tensor_value.dtype,
-    )
-    k_tensor = np.zeros_like(q_tensor, dtype=tensor_value.dtype)
-    v_tensor = np.zeros_like(k_tensor, dtype=tensor_value.dtype)
-
-    j = 0
-    for i in range(0 * chunk_size, hidden_size * 3, chunk_size * 3):
-      q_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
-      j += chunk_size
-
-    j = 0
-    for i in range(1 * chunk_size, hidden_size * 3, chunk_size * 3):
-      k_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
-      j += chunk_size
-
-    j = 0
-    for i in range(2 * chunk_size, hidden_size * 3, chunk_size * 3):
-      v_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
-      j += chunk_size
-
-    return [
-        np.ascontiguousarray(q_tensor.transpose()),
-        np.ascontiguousarray(k_tensor.transpose()),
-        np.ascontiguousarray(v_tensor.transpose()),
-    ]
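Reader's note on the deleted loader above: `_decompose_falcon_qkv` splits Falcon-RW-1B's fused query/key/value weight by walking its rows in interleaved 64-row chunks (q-chunk, k-chunk, v-chunk, repeating). A minimal standalone sketch of that split, assuming the fused array already has `3 * hidden_size` rows; the helper name is hypothetical and the chunk/hidden sizes are the Falcon-RW-1B values hard-coded in the deleted file:

import numpy as np

def split_fused_qkv(fused, chunk_size=64, hidden_size=2048):
  """Splits a fused [3 * hidden_size, ...] QKV array into q, k, v parts.

  Rows are laid out as [q0, k0, v0, q1, k1, v1, ...], each chunk holding
  `chunk_size` rows; this is the interleaving the deleted
  _decompose_falcon_qkv undoes (after transposing back to row-major).
  """
  parts = []
  for part in range(3):  # 0 = q, 1 = k, 2 = v
    rows = [
        fused[i : i + chunk_size]
        for i in range(part * chunk_size, hidden_size * 3, chunk_size * 3)
    ]
    parts.append(np.ascontiguousarray(np.concatenate(rows, axis=0)))
  return parts  # [q, k, v], each with hidden_size rows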
diff --git a/mediapipe/tasks/tools/pytorch_converter_test.py b/mediapipe/tasks/tools/pytorch_converter_test.py
deleted file mode 100644
index 6fb5603840..0000000000
--- a/mediapipe/tasks/tools/pytorch_converter_test.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Unit tests for pytorch_converter."""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-
-from mediapipe.tasks.python.test import test_utils
-from mediapipe.tasks.tools import pytorch_converter
-
-_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
-_PYTORCH_FILE = test_utils.get_test_data_path(
-    os.path.join(_TEST_DATA_DIR, 'falcon_rw_1b_test_weight.pt')
-)
-
-
-class PytorchConverterTest(parameterized.TestCase):
-  VARIABLE_NAMES = [
-      'transformer.word_embeddings.weight',
-      'transformer.h.0.input_layernorm.weight',
-      'transformer.h.0.input_layernorm.bias',
-      'transformer.h.0.self_attention.query_key_value.weight',
-      'transformer.h.0.self_attention.query_key_value.bias',
-      'transformer.h.0.self_attention.dense.weight',
-      'transformer.h.0.self_attention.dense.bias',
-      'transformer.h.0.post_attention_layernorm.weight',
-      'transformer.h.0.post_attention_layernorm.bias',
-      'transformer.h.0.mlp.dense_h_to_4h.weight',
-      'transformer.h.0.mlp.dense_h_to_4h.bias',
-      'transformer.h.0.mlp.dense_4h_to_h.weight',
-      'transformer.h.0.mlp.dense_4h_to_h.bias',
-      'transformer.ln_f.weight',
-      'transformer.ln_f.bias',
-      'lm_head.weight',
-  ]
-
-  def test_init(self):
-    loader = pytorch_converter.PytorchCkptLoader(
-        ckpt_path=_PYTORCH_FILE,
-        is_symmetric=True,
-        attention_quant_bits=8,
-        feedforward_quant_bits=8,
-        embedding_quant_bits=8,
-        special_model='FALCON_RW_1B',
-    )
-    self.assertEqual(loader._ckpt_path, _PYTORCH_FILE)
-    self.assertEqual(loader._is_symmetric, True)
-    self.assertEqual(loader._attention_quant_bits, 8)
-    self.assertEqual(loader._feedforward_quant_bits, 8)
-
-  @parameterized.product(
-      quant_bits=(4, 8),
-  )
-  def test_load_to_actions(self, quant_bits):
-    loader = pytorch_converter.PytorchCkptLoader(
-        ckpt_path=_PYTORCH_FILE,
-        is_symmetric=True,
-        attention_quant_bits=8,
-        feedforward_quant_bits=quant_bits,
-        embedding_quant_bits=8,
-        special_model='FALCON_RW_1B',
-    )
-    actions = loader.load_to_actions()
-    # There are 16 layers in the model, but qkv weight and bias would be
-    # decomposed to q, k, v tensors, so there would be 20 quantization actions.
-    self.assertLen(actions, 20)
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/mediapipe/tasks/tools/safetensors_converter.py b/mediapipe/tasks/tools/safetensors_converter.py
deleted file mode 100644
index 134246b00c..0000000000
--- a/mediapipe/tasks/tools/safetensors_converter.py
+++ /dev/null
@@ -1,301 +0,0 @@
-"""CkptLoader implementation for loading the Safetensors."""
-
-import array
-import enum
-import json
-import os
-from typing import List, Optional
-
-import numpy as np
-import torch
-
-from mediapipe.tasks.tools import converter_base
-
-
-class LayerType(enum.Enum):
-  """Enum for layer type."""
-
-  NONE = 0
-  ATTENTION = 1  # Layer is part of the attention module.
-  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
-  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
-  LAYER_NORM = (
-      4  # Layer is layer normalization before and after attention layer.
-  )
-
-  @classmethod
-  def get_layer_type(cls, layer_name: str):
-    """Gets the layer type of the given layer name."""
-    ffn_layers = [
-        "mlp",
-    ]
-    attn_layers = [
-        "self_attn",
-    ]
-    emb_layers = [
-        "embed_tokens",
-        "lm_head",
-    ]
-    layer_norms = [
-        "input_layernorm",
-        "post_attention_layernorm",
-        "final_layernorm",
-    ]
-    if any(sub_name in layer_name for sub_name in attn_layers):
-      return LayerType.ATTENTION
-    if any(sub_name in layer_name for sub_name in ffn_layers):
-      return LayerType.FEEDFORWARD
-    if any(sub_name in layer_name for sub_name in emb_layers):
-      return LayerType.EMBEDDING
-    if any(sub_name in layer_name for sub_name in layer_norms):
-      return LayerType.LAYER_NORM
-    else:
-      return LayerType.NONE
-
-
-class StablelmMapper(converter_base.LayerActionMapperBase):
-  """LayerActionMapper for handling the StableLM model."""
-
-  # we don't quantize layer norm for stablelm model.
-  NON_QUANTIZED_LAYERS = [
-      "model.norm.weight",
-      "input_layernorm",
-      "post_attention_layernorm",
-  ]
-
-  def map_to_actions(
-      self, layer_name: str
-  ) -> Optional[converter_base.QuantizationAction]:
-    """Map the given layer name to actions."""
-    quantize_axis = None
-    quantize_bits = None
-    layer_type = LayerType.get_layer_type(layer_name)
-
-    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
-      quantize_axis = [0]
-      if layer_type == LayerType.FEEDFORWARD:
-        quantize_bits = self._feedforward_quant_bits
-      elif layer_type == LayerType.ATTENTION:
-        quantize_bits = self._attention_quant_bits
-      elif layer_type == LayerType.EMBEDDING:
-        quantize_bits = self._embedding_quant_bits
-    target_name = self.update_target_name(layer_name)
-
-    return converter_base.QuantizationAction(
-        tensor_name=layer_name,
-        target_name=target_name,
-        quantize_axis=quantize_axis,
-        quantize_bits=quantize_bits,
-        pack_dim=0,
-    )
-
-  def update_target_name(self, target_name: str) -> str:
-    """Updates the target name to match the tensor name convention."""
-    target_name = target_name.replace(
-        "model.layers.", "params.lm.transformer.x_layers_"
-    )
-    target_name = target_name.replace("mlp.up_proj", "ff_layer.ffn_layer1")
-    target_name = target_name.replace("mlp.down_proj", "ff_layer.ffn_layer2")
-    target_name = target_name.replace(
-        "mlp.gate_proj", "ff_layer.ffn_layer1_gate"
-    )
-    target_name = target_name.replace("input_layernorm", "pre_layer_norm")
-    target_name = target_name.replace(
-        "pre_layer_norm.weight", "pre_layer_norm.scale"
-    )
-    target_name = target_name.replace(
-        "post_attention_layernorm", "post_layer_norm"
-    )
-    target_name = target_name.replace(
-        "post_layer_norm.weight", "post_layer_norm.scale"
-    )
-    target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
-    target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
-    target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
-    target_name = target_name.replace("self_attn.o_proj", "self_attention.post")
-    target_name = target_name.replace(
-        "model.embed_tokens", "params.lm.token_embedding"
-    )
-    target_name = target_name.replace("model.norm", "params.lm.final_ln")
-    target_name = target_name.replace("final_ln.weight", "final_ln.scale")
-    target_name = target_name.replace("lm_head", "params.lm.softmax.logits_ffn")
-    target_name = target_name.replace(".weight", ".w")
-
-    return target_name
-
-
-class PhiMapper(converter_base.LayerActionMapperBase):
-  """LayerActionMapper for handling the Phi model."""
-
-  def map_to_actions(
-      self, layer_name: str
-  ) -> Optional[converter_base.QuantizationAction]:
-    """Map the given layer name to actions."""
-    quantize_axis = None
-    quantize_bits = None
-    layer_type = LayerType.get_layer_type(layer_name)
-
-    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
-      quantize_axis = [0]
-      if layer_type == LayerType.FEEDFORWARD:
-        quantize_bits = self._feedforward_quant_bits
-      elif layer_type == LayerType.ATTENTION:
-        quantize_bits = self._attention_quant_bits
-      elif layer_type == LayerType.EMBEDDING:
-        quantize_bits = self._embedding_quant_bits
-    target_name = self.update_target_name(layer_name)
-
-    return converter_base.QuantizationAction(
-        tensor_name=layer_name,
-        target_name=target_name,
-        quantize_axis=quantize_axis,
-        quantize_bits=quantize_bits,
-        pack_dim=0,
-    )
-
-  def update_target_name(self, target_name: str) -> str:
-    """Updates the target name to match the tensor name convention."""
-    target_name = target_name.replace(
-        "model.layers.", "params.lm.transformer.x_layers_"
-    )
-
-    layer_type = LayerType.get_layer_type(target_name)
-    if layer_type == LayerType.FEEDFORWARD:
-      target_name = target_name.replace(".weight", ".linear.w")
-      target_name = target_name.replace(".bias", ".bias.b")
-      target_name = target_name.replace("mlp.fc1", "ff_layer.ffn_layer1")
-      target_name = target_name.replace("mlp.fc2", "ff_layer.ffn_layer2")
-
-    elif layer_type == LayerType.ATTENTION:
-      target_name = target_name.replace(".weight", ".linear.w")
-      target_name = target_name.replace(".bias", ".bias.b")
-      target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
-      target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
-      target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
-      target_name = target_name.replace(
-          "self_attn.dense", "self_attention.post"
-      )
-    elif layer_type == LayerType.EMBEDDING:
-      target_name = target_name.replace(
-          "model.embed_tokens", "params.lm.token_embedding"
-      )
-      target_name = target_name.replace(
-          "lm_head", "params.lm.softmax.logits_ffn"
-      )
-      target_name = target_name.replace(
-          "logits_ffn.weight", "logits_ffn.linear.w"
-      )
-      target_name = target_name.replace("logits_ffn.bias", "logits_ffn.bias.b")
-    elif layer_type == LayerType.LAYER_NORM:
-      target_name = target_name.replace("input_layernorm", "pre_layer_norm")
-      target_name = target_name.replace(
-          "pre_layer_norm.weight", "pre_layer_norm.scale"
-      )
-      target_name = target_name.replace(
-          "model.final_layernorm", "params.lm.final_ln"
-      )
-      target_name = target_name.replace("final_ln.weight", "final_ln.scale")
-      target_name = target_name.replace(".weight", ".w")
-    return target_name
-
-
-DTYPE_MAP = {
-    "F16": torch.float16,
-    "BF16": torch.bfloat16,
-    "F32": torch.float32,
-}
-
-
-class SafetensorsCkptLoader(converter_base.CkptLoaderBase):
-  """CkptLoader implementation for loading the Safetensors."""
-
-  _HEAD_BYTES = 8
-
-  def __init__(
-      self,
-      ckpt_path: str,
-      is_symmetric: bool,
-      attention_quant_bits: int,
-      feedforward_quant_bits: int,
-      embedding_quant_bits: int,
-      special_model: str,
-  ):
-    """Initializes the loader.
-
-    Args:
-      ckpt_path: The filepath to the safetensors file.
-      is_symmetric: Whether to apply symmetric or asymmetric quantization.
-      attention_quant_bits: An integer that specify the target quantization bits
-        (support 8 or 4) for the attention layers.
-      feedforward_quant_bits: An integer that specify the target quantization
-        bits (support 8 or 4) for the feedforward layers in each Transformer
-        blocks.
-      embedding_quant_bits: An integer that specify the target quantization bits
-        (support 8 or 4) for the embedding (and the final projection) layers.
-      special_model: A string that indicates which input model is and whether
-        any special treatment is needed.
-    """
-    super().__init__(
-        ckpt_path,
-        is_symmetric,
-        attention_quant_bits,
-        feedforward_quant_bits,
-        embedding_quant_bits,
-    )
-
-    self._special_model = special_model
-    if special_model in ["STABLELM_4E1T_3B"]:
-      self.mapper = StablelmMapper(
-          is_symmetric,
-          attention_quant_bits,
-          feedforward_quant_bits,
-          embedding_quant_bits,
-      )
-    elif special_model in ["PHI_2"]:
-      self.mapper = PhiMapper(
-          is_symmetric,
-          attention_quant_bits,
-          feedforward_quant_bits,
-          embedding_quant_bits,
-      )
-    else:
-      raise ValueError(f"Unknown special model: {special_model}")
-
-    self._ckpt_path = ckpt_path
-    if not os.path.exists(self._ckpt_path):
-      raise ValueError(f"{self._ckpt_path} does not exists.")
-    with open(self._ckpt_path, "rb") as f:
-      head_bytes = f.read(self._HEAD_BYTES)
-      metadata_bytes_num = np.frombuffer(head_bytes, dtype=np.uint64)[0]
-      metadata_bytes = f.read(metadata_bytes_num)
-      self.layers_info = json.loads(metadata_bytes)
-    self.metadata_bytes_num = metadata_bytes_num
-
-  def load_to_actions(self) -> List[converter_base.QuantizationAction]:
-    tensor_names = self.layers_info.keys()
-    actions = []
-    for tensor_name in tensor_names:
-      if tensor_name == "__metadata__":
-        continue
-      action = self.mapper.map_to_actions(tensor_name)
-      if action is None:
-        continue
-      action.tensor_value = self._read_tensor_as_numpy(tensor_name)
-      actions.append(action)
-    return actions
-
-  def _read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
-    """Reads a tensor from the model file as a numpy array with np.float32 type."""
-    tensor_info = self.layers_info[tensor_name]
-    with open(self._ckpt_path, "rb") as f:
-      shape = tensor_info["shape"]
-      dtype = tensor_info["dtype"]
-      if dtype not in DTYPE_MAP:
-        raise ValueError(f"{dtype} is not supported.")
-      data_offsets = tensor_info["data_offsets"]
-      f.seek(int(self._HEAD_BYTES + self.metadata_bytes_num + data_offsets[0]))
-      tensor_bytes = f.read(data_offsets[1] - data_offsets[0])
-      raw_tensor = torch.frombuffer(
-          array.array("b", tensor_bytes), dtype=DTYPE_MAP[dtype]
-      ).reshape(shape)
-    return raw_tensor.float().t().contiguous().numpy()
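For context on the deleted SafetensorsCkptLoader above: a .safetensors file begins with an 8-byte little-endian length followed by a JSON header mapping each tensor name to its dtype, shape, and data_offsets, which is what the `_HEAD_BYTES` logic reads. A minimal sketch of that header parse, using only the standard library; the helper name is hypothetical:

import json
import struct

def read_safetensors_header(path):
  """Returns the JSON metadata block that SafetensorsCkptLoader parses.

  The first 8 bytes hold the byte length of the JSON header (little-endian
  uint64); the header maps tensor names to dtype, shape, and data_offsets
  relative to the payload that follows it.
  """
  with open(path, "rb") as f:
    (header_len,) = struct.unpack("<Q", f.read(8))
    return json.loads(f.read(header_len))

# Usage (illustrative path):
#   info = read_safetensors_header("stablelm_3b_4e1t_test_weight.safetensors")
#   print(info["model.embed_tokens.weight"]["shape"])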
diff --git a/mediapipe/tasks/tools/safetensors_converter_test.py b/mediapipe/tasks/tools/safetensors_converter_test.py
deleted file mode 100644
index 8e5653ef51..0000000000
--- a/mediapipe/tasks/tools/safetensors_converter_test.py
+++ /dev/null
@@ -1,67 +0,0 @@
-"""Unit tests for safetensors_converter."""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-
-from mediapipe.tasks.python.test import test_utils
-from mediapipe.tasks.tools import safetensors_converter
-
-_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
-_SAFETENSORS_FILE = test_utils.get_test_data_path(
-    os.path.join(_TEST_DATA_DIR, 'stablelm_3b_4e1t_test_weight.safetensors')
-)
-
-
-class SafetensorsConverterTest(parameterized.TestCase):
-  VARIABLE_NAMES = [
-      'model.embed_tokens.weight',
-      'model.layers.0.input_layernorm.bias',
-      'model.layers.0.input_layernorm.weight',
-      'model.layers.0.mlp.down_proj.weight',
-      'model.layers.0.mlp.gate_proj.weight',
-      'model.layers.0.mlp.up_proj.weight',
-      'model.layers.0.post_attention_layernorm.bias',
-      'model.layers.0.post_attention_layernorm.weight',
-      'model.layers.0.self_attn.k_proj.weight',
-      'model.layers.0.self_attn.o_proj.weight',
-      'model.layers.0.self_attn.q_proj.weight',
-      'model.layers.0.self_attn.v_proj.weight',
-      'model.norm.bias',
-      'model.norm.weight',
-      'lm_head.weight',
-  ]
-
-  def test_init(self):
-    loader = safetensors_converter.SafetensorsCkptLoader(
-        ckpt_path=_SAFETENSORS_FILE,
-        is_symmetric=True,
-        attention_quant_bits=8,
-        feedforward_quant_bits=8,
-        embedding_quant_bits=8,
-        special_model='STABLELM_4E1T_3B',
-    )
-    self.assertEqual(loader._ckpt_path, _SAFETENSORS_FILE)
-    self.assertEqual(loader._is_symmetric, True)
-    self.assertEqual(loader._attention_quant_bits, 8)
-    self.assertEqual(loader._feedforward_quant_bits, 8)
-
-  @parameterized.product(
-      quant_bits=(4, 8),
-  )
-  def test_load_to_actions(self, quant_bits):
-    loader = safetensors_converter.SafetensorsCkptLoader(
-        ckpt_path=_SAFETENSORS_FILE,
-        is_symmetric=True,
-        attention_quant_bits=8,
-        feedforward_quant_bits=quant_bits,
-        embedding_quant_bits=8,
-        special_model='STABLELM_4E1T_3B',
-    )
-    actions = loader.load_to_actions()
-    self.assertLen(actions, 15)
-
-
-if __name__ == '__main__':
-  absltest.main()
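As a worked example of the checkpoint-to-runtime renaming convention implemented by the deleted mappers, here is a hand trace of the same string replacements StablelmMapper.update_target_name applies (a sketch for illustration, not an exhaustive mapping):

name = "model.layers.0.self_attn.q_proj.weight"
name = name.replace("model.layers.", "params.lm.transformer.x_layers_")
name = name.replace("self_attn.q_proj", "self_attention.q")
name = name.replace(".weight", ".w")
assert name == "params.lm.transformer.x_layers_0.self_attention.q.w"

# Similarly, "model.embed_tokens.weight" maps to "params.lm.token_embedding.w"
# and "model.norm.weight" maps to "params.lm.final_ln.scale".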