From c0662f6b1969dff4c48611a68aae770dfb2537c2 Mon Sep 17 00:00:00 2001 From: erfanzar Date: Wed, 17 Jul 2024 16:07:21 -0700 Subject: [PATCH] Preparing For version Update `0.0.70` --- .vscode/PythonImportHelper-v2-Completion.json | 345 ++++++------------ README.md | 90 +++-- docs/contributing.rst | 43 +++ docs/index.rst | 68 +++- src/fjformer/core/implicit_array.py | 13 +- src/fjformer/custom_array/array4bit.py | 16 +- src/fjformer/custom_array/array8bit.py | 5 +- src/fjformer/optimizers/adafactor.py | 29 ++ src/fjformer/optimizers/adamw.py | 22 +- src/fjformer/optimizers/lion.py | 33 +- src/fjformer/optimizers/rmsprop.py | 52 ++- test/array4bit.py | 6 +- 12 files changed, 417 insertions(+), 305 deletions(-) create mode 100644 docs/contributing.rst diff --git a/.vscode/PythonImportHelper-v2-Completion.json b/.vscode/PythonImportHelper-v2-Completion.json index fbe5cac..97a049f 100644 --- a/.vscode/PythonImportHelper-v2-Completion.json +++ b/.vscode/PythonImportHelper-v2-Completion.json @@ -212,7 +212,7 @@ "documentation": {} }, { - "label": "ClassVar", + "label": "Any", "importPath": "typing", "description": "typing", "isExtraImport": true, @@ -220,7 +220,7 @@ "documentation": {} }, { - "label": "Optional", + "label": "Callable", "importPath": "typing", "description": "typing", "isExtraImport": true, @@ -228,7 +228,7 @@ "documentation": {} }, { - "label": "Callable", + "label": "ClassVar", "importPath": "typing", "description": "typing", "isExtraImport": true, @@ -236,7 +236,7 @@ "documentation": {} }, { - "label": "Any", + "label": "Optional", "importPath": "typing", "description": "typing", "isExtraImport": true, @@ -1068,31 +1068,7 @@ "documentation": {} }, { - "label": "Any", - "importPath": "typing", - "description": "typing", - "isExtraImport": true, - "detail": "typing", - "documentation": {} - }, - { - "label": "Callable", - "importPath": "typing", - "description": "typing", - "isExtraImport": true, - "detail": "typing", - "documentation": {} - }, - { - "label": "Optional", - "importPath": "typing", - "description": "typing", - "isExtraImport": true, - "detail": "typing", - "documentation": {} - }, - { - "label": "Literal", + "label": "Tuple", "importPath": "typing", "description": "typing", "isExtraImport": true, @@ -1206,6 +1182,14 @@ "detail": "dataclasses", "documentation": {} }, + { + "label": "dataclass", + "importPath": "dataclasses", + "description": "dataclasses", + "isExtraImport": true, + "detail": "dataclasses", + "documentation": {} + }, { "label": "calibration", "importPath": "fjformer.bit_quantization", @@ -1656,7 +1640,7 @@ "documentation": {} }, { - "label": "Array", + "label": "numpy", "importPath": "jax", "description": "jax", "isExtraImport": true, @@ -1664,7 +1648,7 @@ "documentation": {} }, { - "label": "numpy", + "label": "Array", "importPath": "jax", "description": "jax", "isExtraImport": true, @@ -1744,14 +1728,6 @@ "detail": "functools", "documentation": {} }, - { - "label": "partial", - "importPath": "functools", - "description": "functools", - "isExtraImport": true, - "detail": "functools", - "documentation": {} - }, { "label": "numpy", "kind": 6, @@ -1828,14 +1804,6 @@ "detail": "flax", "documentation": {} }, - { - "label": "linen", - "importPath": "flax", - "description": "flax", - "isExtraImport": true, - "detail": "flax", - "documentation": {} - }, { "label": "flax.linen", "kind": 6, @@ -2050,46 +2018,6 @@ "detail": "itertools", "documentation": {} }, - { - "label": "UninitializedAval", - "importPath": "fjformer.core.errors", - "description": "fjformer.core.errors", - "isExtraImport": true, - "detail": "fjformer.core.errors", - "documentation": {} - }, - { - "label": "MaterializationError", - "importPath": "fjformer.core.errors", - "description": "fjformer.core.errors", - "isExtraImport": true, - "detail": "fjformer.core.errors", - "documentation": {} - }, - { - "label": "OperationError", - "importPath": "fjformer.core.errors", - "description": "fjformer.core.errors", - "isExtraImport": true, - "detail": "fjformer.core.errors", - "documentation": {} - }, - { - "label": "ShapeDtypeError", - "importPath": "fjformer.core.errors", - "description": "fjformer.core.errors", - "isExtraImport": true, - "detail": "fjformer.core.errors", - "documentation": {} - }, - { - "label": "UnsupportedPrimitiveError", - "importPath": "fjformer.core.errors", - "description": "fjformer.core.errors", - "isExtraImport": true, - "detail": "fjformer.core.errors", - "documentation": {} - }, { "label": "jax.extend.linear_util", "kind": 6, @@ -2189,6 +2117,46 @@ "detail": "plum", "documentation": {} }, + { + "label": "UninitializedAval", + "importPath": "fjformer.core.errors", + "description": "fjformer.core.errors", + "isExtraImport": true, + "detail": "fjformer.core.errors", + "documentation": {} + }, + { + "label": "MaterializationError", + "importPath": "fjformer.core.errors", + "description": "fjformer.core.errors", + "isExtraImport": true, + "detail": "fjformer.core.errors", + "documentation": {} + }, + { + "label": "OperationError", + "importPath": "fjformer.core.errors", + "description": "fjformer.core.errors", + "isExtraImport": true, + "detail": "fjformer.core.errors", + "documentation": {} + }, + { + "label": "ShapeDtypeError", + "importPath": "fjformer.core.errors", + "description": "fjformer.core.errors", + "isExtraImport": true, + "detail": "fjformer.core.errors", + "documentation": {} + }, + { + "label": "UnsupportedPrimitiveError", + "importPath": "fjformer.core.errors", + "description": "fjformer.core.errors", + "isExtraImport": true, + "detail": "fjformer.core.errors", + "documentation": {} + }, { "label": "ELEMENTWISE_BINOPS", "importPath": "fjformer.core.implicit_array", @@ -2277,14 +2245,6 @@ "detail": "fjformer.core.implicit_array", "documentation": {} }, - { - "label": "implicit_compact", - "importPath": "fjformer.core.implicit_array", - "description": "fjformer.core.implicit_array", - "isExtraImport": true, - "detail": "fjformer.core.implicit_array", - "documentation": {} - }, { "label": "Complement", "importPath": "fjformer.core.types", @@ -2383,6 +2343,38 @@ "detail": "fjformer.core", "documentation": {} }, + { + "label": "ImplicitArray", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "primitive_handler", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "use_implicit_args", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "ArrayValue", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, { "label": "chex", "kind": 6, @@ -2456,14 +2448,6 @@ "detail": "fjformer", "documentation": {} }, - { - "label": "GenerateRNG", - "importPath": "fjformer", - "description": "fjformer", - "isExtraImport": true, - "detail": "fjformer", - "documentation": {} - }, { "label": "flax.core", "kind": 6, @@ -2983,14 +2967,6 @@ "detail": "fjformer.custom_array.array4bit", "documentation": {} }, - { - "label": "Array4Bit", - "importPath": "fjformer.custom_array.array4bit", - "description": "fjformer.custom_array.array4bit", - "isExtraImport": true, - "detail": "fjformer.custom_array.array4bit", - "documentation": {} - }, { "label": "Array8Bit", "importPath": "fjformer.custom_array.array8bit", @@ -3103,46 +3079,6 @@ "detail": "fjformer.optimizers.adamw", "documentation": {} }, - { - "label": "quad", - "importPath": "scipy.integrate", - "description": "scipy.integrate", - "isExtraImport": true, - "detail": "scipy.integrate", - "documentation": {} - }, - { - "label": "root_scalar", - "importPath": "scipy.optimize", - "description": "scipy.optimize", - "isExtraImport": true, - "detail": "scipy.optimize", - "documentation": {} - }, - { - "label": "norm", - "importPath": "scipy.stats", - "description": "scipy.stats", - "isExtraImport": true, - "detail": "scipy.stats", - "documentation": {} - }, - { - "label": "halfnorm", - "importPath": "scipy.stats", - "description": "scipy.stats", - "isExtraImport": true, - "detail": "scipy.stats", - "documentation": {} - }, - { - "label": "truncnorm", - "importPath": "scipy.stats", - "description": "scipy.stats", - "isExtraImport": true, - "detail": "scipy.stats", - "documentation": {} - }, { "label": "project", "kind": 5, @@ -3922,7 +3858,7 @@ "kind": 5, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "_dispatch = Dispatcher()\n_primitive_ids = count()\nclass ArrayValue(ABC):\n pass\nArrayValue.register(jax.Array)\ndef get_lax_primitive_by_name(name: str) -> jax.core.Primitive:\n \"\"\"Get a JAX LAX primitive by its name.\"\"\"\n return getattr(jax.lax, f\"{name}_p\")\ndef get_primitive_handler(primitive):\n \"\"\"Get or create a handler for a given primitive.\"\"\"", + "peekOfCode": "_dispatch = Dispatcher()\n_primitive_ids = count()\nwarnings.filterwarnings(\n \"ignore\", message=\"Could not resolve the type hint of `~B`\", module=\"plum.type\"\n)\nwarnings.filterwarnings(\n \"ignore\", message=\"Could not resolve the type hint of `~A`\", module=\"plum.type\"\n)\nclass ArrayValue(ABC):\n pass", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -3931,7 +3867,7 @@ "kind": 5, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "_primitive_ids = count()\nclass ArrayValue(ABC):\n pass\nArrayValue.register(jax.Array)\ndef get_lax_primitive_by_name(name: str) -> jax.core.Primitive:\n \"\"\"Get a JAX LAX primitive by its name.\"\"\"\n return getattr(jax.lax, f\"{name}_p\")\ndef get_primitive_handler(primitive):\n \"\"\"Get or create a handler for a given primitive.\"\"\"\n if isinstance(primitive, str):", + "peekOfCode": "_primitive_ids = count()\nwarnings.filterwarnings(\n \"ignore\", message=\"Could not resolve the type hint of `~B`\", module=\"plum.type\"\n)\nwarnings.filterwarnings(\n \"ignore\", message=\"Could not resolve the type hint of `~A`\", module=\"plum.type\"\n)\nclass ArrayValue(ABC):\n pass\nArrayValue.register(jax.Array)", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -5110,7 +5046,7 @@ "kind": 2, "importPath": "src.fjformer.optimizers.lion", "description": "src.fjformer.optimizers.lion", - "peekOfCode": "def get_lion_with_linear_scheduler(\n steps: int,\n learning_rate_start: float = 5e-5,\n learning_rate_end: float = 1e-5,\n b1: float = 0.9,\n b2: float = 0.99,\n gradient_accumulation_steps: int = 1,\n mu_dtype: Optional[chex.ArrayDType] = None,\n) -> Tuple[optax.GradientTransformation, optax.Schedule]:\n \"\"\"", + "peekOfCode": "def get_lion_with_linear_scheduler(\n steps: int,\n learning_rate_start: float = 5e-5,\n learning_rate_end: float = 1e-5,\n b1: float = 0.9,\n b2: float = 0.99,\n gradient_accumulation_steps: int = 1,\n mu_dtype: Optional[chex.ArrayDType] = None,\n clip_grad: Optional[float] = None,\n **kwargs,", "detail": "src.fjformer.optimizers.lion", "documentation": {} }, @@ -5128,7 +5064,7 @@ "kind": 2, "importPath": "src.fjformer.optimizers.lion", "description": "src.fjformer.optimizers.lion", - "peekOfCode": "def get_lion_with_cosine_scheduler(\n steps: int,\n learning_rate: float = 5e-5,\n alpha: float = 0.0,\n exponent: float = 1.0,\n b1: float = 0.9,\n b2: float = 0.99,\n gradient_accumulation_steps: int = 1,\n mu_dtype: Optional[chex.ArrayDType] = None,\n) -> Tuple[optax.GradientTransformation, optax.Schedule]:", + "peekOfCode": "def get_lion_with_cosine_scheduler(\n steps: int,\n learning_rate: float = 5e-5,\n alpha: float = 0.0,\n exponent: float = 1.0,\n b1: float = 0.9,\n b2: float = 0.99,\n gradient_accumulation_steps: int = 1,\n mu_dtype: Optional[chex.ArrayDType] = None,\n clip_grad: Optional[float] = None,", "detail": "src.fjformer.optimizers.lion", "documentation": {} }, @@ -6789,120 +6725,57 @@ "documentation": {} }, { - "label": "Model", + "label": "QuantizedArray", "kind": 6, "importPath": "env", "description": "env", - "peekOfCode": "class Model(nn.Module):\n \"\"\"A simple linear model for demonstration.\"\"\"\n def setup(self) -> None:\n \"\"\"Initializes the model layers.\"\"\"\n self.fc = nn.Dense(512, use_bias=True, dtype=jnp.float32)\n self.fc1 = nn.Dense(64, use_bias=True, dtype=jnp.float32)\n self.out = nn.Dense(1, use_bias=True, dtype=jnp.float32)\n def __call__(self, x):\n \"\"\"Performs a forward pass through the model.\"\"\"\n x = self.fc(x)", + "peekOfCode": "class QuantizedArray(ImplicitArray):\n array_quant: ArrayValue\n scale: ArrayValue\n min_vals: ArrayValue\n def materialize(self):\n return self.dequantize(\n array_quant=self.array_quant,\n scale=self.scale,\n min_vals=self.min_vals,\n float_dtype=self.dtype,", "detail": "env", "documentation": {} }, { - "label": "quantize_params", + "label": "quantize", "kind": 2, "importPath": "env", "description": "env", - "peekOfCode": "def quantize_params(\n params: dict,\n block_size: Literal[32, 64, 128, 256, 512, 1024, 2048, 4096] = 64,\n contraction_axis: int = -1,\n factors: Optional[Array] = None,\n) -> dict:\n \"\"\"Quantizes model parameters using Array4Bit.\n Args:\n params: A dictionary of model parameters.\n Returns:", + "peekOfCode": "def quantize(array: Array, axis: int = -1) -> Tuple[Array, Array, Array]:\n min_vals = jnp.min(array, axis=axis, keepdims=True)\n max_vals = jnp.max(array, axis=axis, keepdims=True)\n # Compute the scaling factors\n scale = (max_vals - min_vals) / (2**7 - 1)\n # Quantize the data\n quantized_data = jnp.round((array - min_vals) / scale)\n # Clip the quantized values to ensure they lie within the representable range\n quantized_data = jnp.clip(quantized_data, 0, 2**7 - 1).astype(jnp.uint8)\n return quantized_data, scale, min_vals", "detail": "env", "documentation": {} }, { - "label": "main", + "label": "dequantize", "kind": 2, "importPath": "env", "description": "env", - "peekOfCode": "def main():\n \"\"\"\n Demonstrates the quantization process using a simple model.\n - Initializes a model and random input data.\n - Quantizes the model parameters.\n - Performs inference using both the original and quantized models.\n - Prints the output of both models for comparison.\n \"\"\"\n model = Model()\n init_x = jax.random.normal(rng.rng, (1, 64))", + "peekOfCode": "def dequantize(\n array_quant: Array,\n scale: Array,\n min_vals: Array,\n float_dtype: jnp.dtype = jnp.float16,\n):\n return (array_quant * scale + min_vals).astype(float_dtype)\n@dataclass\nclass QuantizedArray(ImplicitArray):\n array_quant: ArrayValue", "detail": "env", "documentation": {} }, { - "label": "rng", - "kind": 5, + "label": "get_binop_result_shape_dtype", + "kind": 2, "importPath": "env", "description": "env", - "peekOfCode": "rng = GenerateRNG()\nclass Model(nn.Module):\n \"\"\"A simple linear model for demonstration.\"\"\"\n def setup(self) -> None:\n \"\"\"Initializes the model layers.\"\"\"\n self.fc = nn.Dense(512, use_bias=True, dtype=jnp.float32)\n self.fc1 = nn.Dense(64, use_bias=True, dtype=jnp.float32)\n self.out = nn.Dense(1, use_bias=True, dtype=jnp.float32)\n def __call__(self, x):\n \"\"\"Performs a forward pass through the model.\"\"\"", + "peekOfCode": "def get_binop_result_shape_dtype(a, b):\n out_shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b))\n out_dtype = jnp.result_type(a.dtype, b.dtype)\n return out_shape, out_dtype\n@jax.jit\n@use_implicit_args\ndef f(x, y):\n return (x + y)[0, 0]\ndef main():\n orginal_array = jax.random.normal(jax.random.PRNGKey(0), (512, 64))", "detail": "env", "documentation": {} }, { - "label": "integrand", + "label": "f", "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def integrand(block_size, x, m):\n p_z_less_than_mx = truncnorm.cdf(m * x, -m, m)\n pm = block_size * (halfnorm.cdf(m) ** (block_size - 1)) * 2 * norm.pdf(m)\n return p_z_less_than_mx * pm\ndef scaled_norm_cdf(block_size, x):\n result = quad(\n partial(integrand, block_size, x),\n 0,\n np.inf,\n epsabs=1e-9,", - "detail": "experimental", - "documentation": {} - }, - { - "label": "scaled_norm_cdf", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def scaled_norm_cdf(block_size, x):\n result = quad(\n partial(integrand, block_size, x),\n 0,\n np.inf,\n epsabs=1e-9,\n )\n return result[0]\ndef cdf(x, block_size):\n discrete_mass = 1 / (2 * block_size)", - "detail": "experimental", - "documentation": {} - }, - { - "label": "cdf", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def cdf(x, block_size):\n discrete_mass = 1 / (2 * block_size)\n cont_mass = scaled_norm_cdf(block_size, x) * (block_size - 1) / block_size\n result = discrete_mass + cont_mass\n result = np.where(x < -1, 0, result)\n result = np.where(x >= 1, 1, result)\n return result\ndef inv_cdf(val, block_size):\n edge_mass = 1 / (2 * block_size)\n if val <= edge_mass:", - "detail": "experimental", - "documentation": {} - }, - { - "label": "inv_cdf", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def inv_cdf(val, block_size):\n edge_mass = 1 / (2 * block_size)\n if val <= edge_mass:\n return -1\n if val >= 1 - edge_mass:\n return 1\n def search_fn(x):\n return cdf(x, block_size) - val\n return root_scalar(search_fn, bracket=[-1, 1]).root\ndef build_code(", - "detail": "experimental", - "documentation": {} - }, - { - "label": "build_code", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def build_code(\n start,\n lower_bound,\n upper_bound,\n n_steps,\n bcdf,\n binv_cdf,\n lower_bound_is_code_point=True,\n):\n code = [start]", - "detail": "experimental", - "documentation": {} - }, - { - "label": "interval_code_search", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def interval_code_search(\n lower_bound,\n upper_bound,\n n_steps,\n block_size,\n bounds_are_code_points=True,\n):\n bcdf = partial(cdf, block_size=block_size)\n binv_cdf = partial(inv_cdf, block_size=block_size)\n code_builder = partial(", - "detail": "experimental", - "documentation": {} - }, - { - "label": "construct_af4", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def construct_af4(block_size):\n lower = interval_code_search(-1, 0, 5, block_size)\n upper = -interval_code_search(-1, 0, 6, block_size)[::-1]\n code = np.asarray([-1.0, *lower, 0.0, *upper, 1.0], dtype=np.float64)\n assert code.shape == (16,)\n return code\ndef main():\n for block_size in (2048,):\n start = time.time()\n code = construct_af4(block_size)", - "detail": "experimental", + "importPath": "env", + "description": "env", + "peekOfCode": "def f(x, y):\n return (x + y)[0, 0]\ndef main():\n orginal_array = jax.random.normal(jax.random.PRNGKey(0), (512, 64))\n quantized_array = QuantizedArray.quantize(orginal_array)\n print(f(quantized_array, jnp.ones(64)))\n print((orginal_array + jnp.ones(64))[0, 0])\nif __name__ == \"__main__\":\n main()", + "detail": "env", "documentation": {} }, { "label": "main", "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def main():\n for block_size in (2048,):\n start = time.time()\n code = construct_af4(block_size)\n end = time.time()\n print(f\"B = {block_size} - Runtime: {end - start:.2f} sec\\n{code}\\n\")\n np.save(f\"af4_{block_size}.npy\", code)\ndef view():\n for block_size in (2048,):\n array = np.load(f\"af4_{block_size}.npy\")", - "detail": "experimental", - "documentation": {} - }, - { - "label": "view", - "kind": 2, - "importPath": "experimental", - "description": "experimental", - "peekOfCode": "def view():\n for block_size in (2048,):\n array = np.load(f\"af4_{block_size}.npy\")\n print(array)\nif __name__ == \"__main__\":\n main()\n # view()", - "detail": "experimental", + "importPath": "env", + "description": "env", + "peekOfCode": "def main():\n orginal_array = jax.random.normal(jax.random.PRNGKey(0), (512, 64))\n quantized_array = QuantizedArray.quantize(orginal_array)\n print(f(quantized_array, jnp.ones(64)))\n print((orginal_array + jnp.ones(64))[0, 0])\nif __name__ == \"__main__\":\n main()", + "detail": "env", "documentation": {} }, { diff --git a/README.md b/README.md index 0b0bc03..3576f2f 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,76 @@ # FJFormer -Embark on a journey of paralleled/unparalleled computational prowess with FJFormer - an arsenal of custom Jax Flax -Functions and Utils that elevate your AI endeavors to new heights! +[![PyPI version](https://badge.fury.io/py/fjformer.svg)](https://badge.fury.io/py/fjformer) +[![Documentation Status](https://readthedocs.org/projects/fjformer/badge/?version=latest)](https://fjformer.readthedocs.io/en/latest/?badge=latest) +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -## Overview +FJFormer is a powerful and flexible JAX-based package designed to accelerate and simplify machine learning and deep learning workflows. It provides a comprehensive suite of tools and utilities for efficient model development, training, and deployment. -FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It -includes -checkpoint savers, partitioning tools, and other helpful functions. -The goal of FJFormer is to make your life easier when working with Flax and JAX. Whether you are training a new model, -fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJFormer offers +## Features -- Pallas Kernels for GPU,TPU -- BITComputations for 8,6,4 BIT Flax Models -- Built-in functions and Loss functions -- Distributed and sharding Model Loaders and Checkpoint Savers -- Monitoring Utils for *TPU/GPU/CPU* memory `foot-print` -- Optimizers -- Special Optimizers with schedulers and Easy to Use -- Partitioning Utils -- LoRA +### 1. JAX Sharding Utils +Leverage the power of distributed computing and model parallelism with our advanced JAX sharding utilities. These tools enable efficient splitting and management of large models across multiple devices, enhancing performance and enabling the training of larger models. -And a lot of these features are fully documented so FJFormer has something -to offer, and it's not just a Computation BackEnd for [EasyDel](https://github.com/erfanzar/EasyDel). +### 2. Custom Pallas / Triton Operation Kernels +Boost your model's performance with our optimized kernels for specific operations. These custom-built kernels, implemented using Pallas and Triton, provide significant speedups for common bottleneck operations in deep learning models. -checkout for documentation [here](https://fjformer.readthedocs.io/en/latest/). +### 3. Pre-built Optimizers +Jump-start your training with our collection of ready-to-use, efficiently implemented optimization algorithms: +- **AdamW**: An Adam variant with decoupled weight decay. +- **Adafactor**: Memory-efficient adaptive optimization algorithm. +- **Lion**: Recently proposed optimizer combining the benefits of momentum and adaptive methods. +- **RMSprop**: Adaptive learning rate optimization algorithm. -## Contributing +### 4. Utility Functions +A rich set of utility functions to streamline your workflow, including: +- Various loss functions (e.g., cross-entropy) +- Metrics calculation +- Data preprocessing tools -FJFormer is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or -just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or -open an issue. +### 5. ImplicitArray +Our innovative ImplicitArray class provides a powerful abstraction for representing and manipulating large arrays without instantiation. Benefits include: +- Lazy evaluation for memory efficiency +- Optimized array operations in JAX +- Seamless integration with other FJFormer components -Thank you for using FJFormer, and happy training! +### 6. Custom Dtypes +- Implement 4-bit quantization (NF4) effortlessly using our Array4Bit class, built on top of ImplicitArray. Reduce model size and increase inference speed without significant loss in accuracy. + +- Similar to Array4Bit, our Array8Bit implementation offers 8-bit quantization via ImplicitArray, providing a balance between model compression and precision. + +### 7. LoRA (Low-Rank Adaptation) +Efficiently fine-tune large language models with our LoRA implementation, leveraging ImplicitArray for optimal performance and memory usage. + +### 8. JAX and Array Manipulation +A comprehensive set of tools and utilities for efficient array operations and manipulations in JAX, designed to complement and extend JAX's native capabilities. + +### 9. Checkpoint Managers +Robust utilities for managing model checkpoints, including: +- Efficient saving and loading of model states +- Version control for checkpoints +- Integration with distributed training workflows + +## Installation + +You can install FJFormer using pip: + +```bash +pip install fjformer +``` + +For the latest development version, you can install directly from GitHub: + +```bash +pip install git+https://github.com/yourusername/fjformer.git +``` + +## Documentation + +For detailed documentation, including API references, please visit: + +[https://fjformer.readthedocs.org](https://fjformer.readthedocs.org) + +## License + +FJFormer is released under the Apache License 2.0. See the [LICENSE](LICENSE) file for more details. diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100644 index 0000000..075d19c --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1,43 @@ +Contributing to FJFormer +========== +Thank you for considering contributing to FJFormer! We welcome your input. To ensure a smooth collaboration, please review and adhere to the following guidelines. + + +How to Contribute +------ +To contribute to EasyDeL, follow these steps: +1. Fork the repository. +2. Create a new branch for your feature or bug fix. +3. Make your changes and commit them with clear and descriptive messages. +4. Push your changes to your branch in your forked repository. +5. Submit a pull request to the main EasyDeL repository, detailing the changes you've made and the problem it solves. + + +Code of Conduct +------ +Please adhere to the `Apache Code of Conduct `_ in all interactions related to EasyDeL. + +Reporting Bugs +------ +If you encounter a bug, please open an issue on the EasyDeL repository, providing a clear and detailed description of the issue, including steps to reproduce it. + +Suggesting Enhancements +------ +If you have ideas for enhancements, feel free to open an issue on the EasyDeL repository. Provide a clear and detailed description of your proposed enhancement. + +Development Setup +------ +To set up EasyDeL for development, follow the instructions in the README.md file. + +Pull Request Guidelines +------ +When submitting a pull request, please ensure the following: +- Your code follows the project's coding standards. +- Your commits are accompanied by clear and descriptive messages. +- Your pull request addresses a single issue or feature. + +License +------ +By contributing to EasyDeL, you agree that your contributions will be licensed under the Apache License, Version 2.0. + +Thank you for your interest in contributing to EasyDeL! We appreciate your support. diff --git a/docs/index.rst b/docs/index.rst index a6aa00f..6be028b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,26 +1,52 @@ FJFormer 🔮 ========== -Embark on a journey of paralleled/unparalleled computational prowess with FJFormer - an arsenal of custom Jax Flax -Functions and Utils that elevate your AI endeavors to new heights! +FJFormer is a powerful and flexible JAX-based package designed to accelerate and simplify machine learning and deep learning workflows. It provides a comprehensive suite of tools and utilities for efficient model development, training, and deployment. -Overview +Features ---------- -FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It -includes -checkpoint savers, partitioning tools, and other helpful functions. -The goal of FJFormer is to make your life easier when working with Flax and JAX. Whether you are training a new model, -fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJFormer offers - -- Pallas Kernels for GPU,TPU -- BITComputations for 8,6,4 BIT Flax Models -- Built-in functions and Loss functions -- Distributed and sharding Model Loaders and Checkpoint Savers -- Monitoring Utils for *TPU/GPU/CPU* memory `foot-print` -- Optimizers -- Special Optimizers with schedulers and Easy to Use -- Partitioning Utils -- LoRA +1. JAX Sharding Utils +Leverage the power of distributed computing and model parallelism with our advanced JAX sharding utilities. These tools enable efficient splitting and management of large models across multiple devices, enhancing performance and enabling the training of larger models. + +2. Custom Pallas / Triton Operation Kernels +Boost your model's performance with our optimized kernels for specific operations. These custom-built kernels, implemented using Pallas and Triton, provide significant speedups for common bottleneck operations in deep learning models. + +3. Pre-built Optimizers +Jump-start your training with our collection of ready-to-use, efficiently implemented optimization algorithms: +- **AdamW**: An Adam variant with decoupled weight decay. +- **Adafactor**: Memory-efficient adaptive optimization algorithm. +- **Lion**: Recently proposed optimizer combining the benefits of momentum and adaptive methods. +- **RMSprop**: Adaptive learning rate optimization algorithm. + +4. Utility Functions +A rich set of utility functions to streamline your workflow, including: +- Various loss functions (e.g., cross-entropy) +- Metrics calculation +- Data preprocessing tools + +5. ImplicitArray +Our innovative ImplicitArray class provides a powerful abstraction for representing and manipulating large arrays without instantiation. Benefits include: +- Lazy evaluation for memory efficiency +- Optimized array operations in JAX +- Seamless integration with other FJFormer components + +6. Custom Dtypes + +- Implement 4-bit quantization (NF4) effortlessly using our Array4Bit class, built on top of ImplicitArray. Reduce model size and increase inference speed without significant loss in accuracy. + +- Similar to Array4Bit, our Array8Bit implementation offers 8-bit quantization via ImplicitArray, providing a balance between model compression and precision. + +7. LoRA (Low-Rank Adaptation) +Efficiently fine-tune large language models with our LoRA implementation, leveraging ImplicitArray for optimal performance and memory usage. + +8. JAX and Array Manipulation +A comprehensive set of tools and utilities for efficient array operations and manipulations in JAX, designed to complement and extend JAX's native capabilities. + +9. Checkpoint Managers +Robust utilities for managing model checkpoints, including: +- Efficient saving and loading of model states +- Version control for checkpoints +- Integration with distributed training workflows .. _FJFormer: Zare Chavoshi, Erfan. "FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX."" @@ -34,3 +60,9 @@ Zare Chavoshi, Erfan. "FJFormer is a collection of functions and utilities that api_docs/APIs +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Getting Started + + contributing \ No newline at end of file diff --git a/src/fjformer/core/implicit_array.py b/src/fjformer/core/implicit_array.py index 46145e6..f04398c 100644 --- a/src/fjformer/core/implicit_array.py +++ b/src/fjformer/core/implicit_array.py @@ -16,8 +16,8 @@ from dataclasses import dataclass, field, fields, is_dataclass from functools import partial, wraps from itertools import chain, count -from typing import ClassVar, Optional, Callable, Any, Tuple -from fjformer.core.errors import UninitializedAval +from typing import Any, Callable, ClassVar, Optional, Tuple + import jax import jax.extend.linear_util as lu import jax.interpreters.partial_eval as pe @@ -27,9 +27,16 @@ from jax.tree_util import register_pytree_with_keys_class from plum import Dispatcher, Function +from fjformer.core.errors import UninitializedAval + _dispatch = Dispatcher() _primitive_ids = count() - +warnings.filterwarnings( + "ignore", message="Could not resolve the type hint of `~B`", module="plum.type" +) +warnings.filterwarnings( + "ignore", message="Could not resolve the type hint of `~A`", module="plum.type" +) class ArrayValue(ABC): pass diff --git a/src/fjformer/custom_array/array4bit.py b/src/fjformer/custom_array/array4bit.py index a8e1833..73a620b 100644 --- a/src/fjformer/custom_array/array4bit.py +++ b/src/fjformer/custom_array/array4bit.py @@ -467,12 +467,20 @@ def handle_transpose( """ original_quantized = False if isinstance(operand, Array4Bit): - operand = operand.materialize() + array = operand.materialize() original_quantized = True - operand = lax.transpose(operand, *args, **kwargs) + else: + array = operand + array = lax.transpose(array, *args, **kwargs) if original_quantized: - operand = Array4Bit.quantize(operand, dtype=operand.dtype) - return operand + array = Array4Bit.quantize( + array=array, + block_size=operand.block_size, + contraction_axis=operand.contraction_axis, + dtype=operand.dtype, + factors=operand.factors, + ) + return array @core.primitive_handler("conv_general_dilated") diff --git a/src/fjformer/custom_array/array8bit.py b/src/fjformer/custom_array/array8bit.py index 5a6baa4..027ea9a 100644 --- a/src/fjformer/custom_array/array8bit.py +++ b/src/fjformer/custom_array/array8bit.py @@ -69,7 +69,10 @@ def materialize(self) -> Array: @classmethod def quantize( - cls, array: Array, axis: int = -1, dtype: Optional[jnp.dtype] = None + cls, + array: Array, + axis: int = -1, + dtype: Optional[jnp.dtype] = None, ) -> "Array8Bit": """ Quantize a JAX array to 8-bit representation. diff --git a/src/fjformer/optimizers/adafactor.py b/src/fjformer/optimizers/adafactor.py index 031e142..bcf92a1 100644 --- a/src/fjformer/optimizers/adafactor.py +++ b/src/fjformer/optimizers/adafactor.py @@ -5,6 +5,7 @@ import optax from fjformer.optimizers.optimizer_utils import optax_add_scheduled_weight_decay +import warnings def _get_adafactor_base( @@ -22,6 +23,8 @@ def _get_adafactor_base( weight_decay: float = 0.0, weight_decay_mask: Optional[Any] = None, gradient_accumulation_steps: int = 1, + clip_grad: Optional[float] = None, + **kwargs ) -> optax.GradientTransformation: """ Creates a base Adafactor optimizer with the given scheduler and options. @@ -41,10 +44,14 @@ def _get_adafactor_base( weight_decay (float): Additional weight decay factor. weight_decay_mask (Optional[Any]): Mask for weight decay. gradient_accumulation_steps (int): Number of steps to accumulate gradients. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: optax.GradientTransformation: The configured optimizer. """ + + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") chain = [ optax.adafactor( learning_rate=scheduler, @@ -60,6 +67,8 @@ def _get_adafactor_base( factored=factored, ) ] + if clip_grad is not None: + chain.insert(0, optax.clip_by_global_norm(clip_grad)) if weight_decay != 0.0: chain.append( @@ -93,6 +102,8 @@ def get_adafactor_with_linear_scheduler( factored: bool = True, gradient_accumulation_steps: int = 1, weight_decay_mask: Optional[Any] = None, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an Adafactor optimizer with a linear learning rate scheduler. @@ -114,6 +125,7 @@ def get_adafactor_with_linear_scheduler( factored (bool): Whether to use factored second moment estimates. gradient_accumulation_steps (int): Number of steps to accumulate gradients. weight_decay_mask (Optional[Any]): Mask for weight decay. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -139,6 +151,8 @@ def get_adafactor_with_linear_scheduler( weight_decay=weight_decay, weight_decay_mask=weight_decay_mask, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler @@ -160,6 +174,8 @@ def get_adafactor_with_warmup_linear_scheduler( eps: float = 1e-30, factored: bool = True, gradient_accumulation_steps: int = 1, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an Adafactor optimizer with a warm-up linear learning rate scheduler. @@ -180,6 +196,7 @@ def get_adafactor_with_warmup_linear_scheduler( eps (float): Epsilon for numerical stability. factored (bool): Whether to use factored second moment estimates. gradient_accumulation_steps (int): Number of steps to accumulate gradients. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -210,6 +227,8 @@ def get_adafactor_with_warmup_linear_scheduler( eps=eps, factored=factored, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler_combined @@ -229,6 +248,8 @@ def get_adafactor_with_cosine_scheduler( eps: float = 1e-30, factored: bool = True, gradient_accumulation_steps: int = 1, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an Adafactor optimizer with a cosine learning rate scheduler. @@ -247,6 +268,7 @@ def get_adafactor_with_cosine_scheduler( eps (float): Epsilon for numerical stability. factored (bool): Whether to use factored second moment estimates. gradient_accumulation_steps (int): Number of steps to accumulate gradients. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -266,6 +288,8 @@ def get_adafactor_with_cosine_scheduler( eps=eps, factored=factored, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler @@ -290,6 +314,8 @@ def get_adafactor_with_warmup_cosine_scheduler( weight_decay_mask: Optional[Any] = None, gradient_accumulation_steps: int = 1, warmup_steps: int = 100, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an Adafactor optimizer with a warm-up cosine learning rate scheduler. @@ -313,6 +339,7 @@ def get_adafactor_with_warmup_cosine_scheduler( weight_decay_mask (Optional[Any]): Mask for weight decay. gradient_accumulation_steps (int): Number of steps to accumulate gradients. warmup_steps (int): Number of warm-up steps. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -341,6 +368,8 @@ def get_adafactor_with_warmup_cosine_scheduler( weight_decay=weight_decay, weight_decay_mask=weight_decay_mask, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler diff --git a/src/fjformer/optimizers/adamw.py b/src/fjformer/optimizers/adamw.py index d7ede53..abcf434 100644 --- a/src/fjformer/optimizers/adamw.py +++ b/src/fjformer/optimizers/adamw.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Tuple import chex @@ -13,7 +14,8 @@ def _get_adamw_base( weight_decay: float = 1e-1, gradient_accumulation_steps: int = 1, mu_dtype: Optional[chex.ArrayDType] = None, - clip_grad: Optional[float] = 1.0, + clip_grad: Optional[float] = None, + **kwargs, ) -> optax.GradientTransformation: """ Creates a base AdamW optimizer with the given scheduler. @@ -32,6 +34,8 @@ def _get_adamw_base( Returns: optax.GradientTransformation: The configured optimizer. """ + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") chain = [ optax.scale_by_adam( b1=b1, @@ -66,7 +70,8 @@ def get_adamw_with_cosine_scheduler( weight_decay: float = 1e-1, gradient_accumulation_steps: int = 1, mu_dtype: Optional[chex.ArrayDType] = None, - clip_grad: Optional[float] = 1.0, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an AdamW optimizer with a cosine learning rate scheduler. @@ -97,6 +102,7 @@ def get_adamw_with_cosine_scheduler( gradient_accumulation_steps=gradient_accumulation_steps, mu_dtype=mu_dtype, clip_grad=clip_grad, + **kwargs, ) return tx, scheduler @@ -112,7 +118,8 @@ def get_adamw_with_linear_scheduler( weight_decay: float = 1e-1, gradient_accumulation_steps: int = 1, mu_dtype: Optional[chex.ArrayDType] = None, - clip_grad: Optional[float] = 1.0, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an AdamW optimizer with a linear learning rate scheduler. @@ -148,6 +155,7 @@ def get_adamw_with_linear_scheduler( gradient_accumulation_steps=gradient_accumulation_steps, mu_dtype=mu_dtype, clip_grad=clip_grad, + **kwargs, ) return tx, scheduler @@ -165,7 +173,8 @@ def get_adamw_with_warmup_linear_scheduler( mu_dtype: Optional[chex.ArrayDType] = None, warmup_steps: int = 100, warmup_init_value: float = 5e-8, - clip_grad: Optional[float] = 1.0, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an AdamW optimizer with a linear learning rate scheduler and warm-up phase. @@ -213,6 +222,7 @@ def get_adamw_with_warmup_linear_scheduler( gradient_accumulation_steps=gradient_accumulation_steps, mu_dtype=mu_dtype, clip_grad=clip_grad, + **kwargs, ) return tx, scheduler_combined @@ -231,7 +241,8 @@ def get_adamw_with_warmup_cosine_scheduler( warmup_steps: int = 100, mu_dtype: Optional[chex.ArrayDType] = None, warmup_init_value: float = 0.5e-7, - clip_grad: Optional[float] = 1.0, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an AdamW optimizer with a cosine learning rate scheduler and warm-up phase. @@ -273,5 +284,6 @@ def get_adamw_with_warmup_cosine_scheduler( gradient_accumulation_steps=gradient_accumulation_steps, mu_dtype=mu_dtype, clip_grad=clip_grad, + **kwargs, ) return tx, scheduler diff --git a/src/fjformer/optimizers/lion.py b/src/fjformer/optimizers/lion.py index d33d9b1..ae7ab50 100644 --- a/src/fjformer/optimizers/lion.py +++ b/src/fjformer/optimizers/lion.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Tuple import chex @@ -10,6 +11,8 @@ def _get_lion_base( b2: float = 0.99, mu_dtype: Optional[chex.ArrayDType] = None, gradient_accumulation_steps: int = 1, + clip_grad: Optional[float] = None, + **kwargs, ) -> optax.GradientTransformation: """ Creates a base Lion optimizer with the given scheduler. @@ -20,15 +23,21 @@ def _get_lion_base( b2 (float): The exponential decay rate for the second moment estimates. mu_dtype (Optional[chex.ArrayDType]): Optional datatype for the first moment estimates. gradient_accumulation_steps (int): Number of steps to accumulate gradients. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: optax.GradientTransformation: The configured optimizer. """ - tx = optax.chain( + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") + chain = [ optax.scale_by_lion(b1=b1, b2=b2, mu_dtype=mu_dtype), optax.scale_by_schedule(scheduler), optax.scale(-1), - ) + ] + if clip_grad is not None: + chain.insert(0, optax.clip_by_global_norm(clip_grad)) + tx = optax.chain(*chain) if gradient_accumulation_steps > 1: tx = optax.MultiSteps(tx, gradient_accumulation_steps) return tx @@ -42,6 +51,8 @@ def get_lion_with_linear_scheduler( b2: float = 0.99, gradient_accumulation_steps: int = 1, mu_dtype: Optional[chex.ArrayDType] = None, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates a Lion optimizer with a linear learning rate scheduler. @@ -54,6 +65,7 @@ def get_lion_with_linear_scheduler( b2 (float): The exponential decay rate for the second moment estimates. gradient_accumulation_steps (int): Number of steps to accumulate gradients. mu_dtype (Optional[chex.ArrayDType]): Optional datatype for the first moment estimates. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -69,6 +81,8 @@ def get_lion_with_linear_scheduler( b2=b2, mu_dtype=mu_dtype, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler @@ -83,6 +97,8 @@ def get_lion_with_warmup_linear_scheduler( mu_dtype: Optional[chex.ArrayDType] = None, warmup_steps: int = 100, warmup_init_value: float = 5e-8, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates a Lion optimizer with a warm-up linear learning rate scheduler. @@ -97,6 +113,7 @@ def get_lion_with_warmup_linear_scheduler( mu_dtype (Optional[chex.ArrayDType]): Optional datatype for the first moment estimates. warmup_steps (int): Number of warm-up steps. warmup_init_value (float): Initial learning rate for warm-up. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -121,6 +138,8 @@ def get_lion_with_warmup_linear_scheduler( b2=b2, mu_dtype=mu_dtype, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler_combined @@ -134,6 +153,8 @@ def get_lion_with_cosine_scheduler( b2: float = 0.99, gradient_accumulation_steps: int = 1, mu_dtype: Optional[chex.ArrayDType] = None, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates a Lion optimizer with a cosine learning rate scheduler. @@ -147,6 +168,7 @@ def get_lion_with_cosine_scheduler( b2 (float): The exponential decay rate for the second moment estimates. gradient_accumulation_steps (int): Number of steps to accumulate gradients. mu_dtype (Optional[chex.ArrayDType]): Optional datatype for the first moment estimates. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -167,6 +189,8 @@ def get_lion_with_cosine_scheduler( b2=b2, mu_dtype=mu_dtype, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler @@ -182,6 +206,8 @@ def get_lion_with_warmup_cosine_scheduler( warmup_steps: int = 100, mu_dtype: Optional[chex.ArrayDType] = None, warmup_init_value: float = 0.5e-7, + clip_grad: Optional[float] = None, + **kwargs, ) -> Tuple[optax.GradientTransformation, optax.Schedule]: """ Creates a Lion optimizer with a warm-up cosine learning rate scheduler. @@ -197,6 +223,7 @@ def get_lion_with_warmup_cosine_scheduler( warmup_steps (int): Number of warm-up steps. mu_dtype (Optional[chex.ArrayDType]): Optional datatype for the first moment estimates. warmup_init_value (float): Initial learning rate for warm-up. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: Tuple[optax.GradientTransformation, optax.Schedule]: The optimizer and scheduler. @@ -215,5 +242,7 @@ def get_lion_with_warmup_cosine_scheduler( b2=b2, mu_dtype=mu_dtype, gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + **kwargs, ) return tx, scheduler diff --git a/src/fjformer/optimizers/rmsprop.py b/src/fjformer/optimizers/rmsprop.py index f35eb0a..0cb58f5 100644 --- a/src/fjformer/optimizers/rmsprop.py +++ b/src/fjformer/optimizers/rmsprop.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional import optax @@ -13,6 +14,8 @@ def get_rmsprop_with_cosine_scheduler( eps: float = 1e-8, weight_decay: float = 1e-1, gradient_accumulation_steps: int = 1, + clip_grad: Optional[float] = None, + **kwargs, ) -> tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an RMSprop optimizer with a cosine learning rate scheduler. @@ -27,16 +30,19 @@ def get_rmsprop_with_cosine_scheduler( eps: A small value added to the denominator for numerical stability. weight_decay: The weight decay rate. gradient_accumulation_steps: The number of steps to accumulate gradients over. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: A tuple containing the optimizer and the learning rate scheduler. """ + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") scheduler = optax.cosine_decay_schedule( init_value=learning_rate, decay_steps=steps, ) - tx = optax.chain( + chain = [ optax.scale_by_rms( decay=decay, eps=eps, @@ -55,7 +61,11 @@ def get_rmsprop_with_cosine_scheduler( optax.add_decayed_weights( weight_decay=weight_decay, ), - ) + ] + + if clip_grad is not None: + chain.insert(0, optax.clip_by_global_norm(clip_grad)) + tx = optax.chain(*chain) if gradient_accumulation_steps > 1: tx = optax.MultiSteps(tx, gradient_accumulation_steps) return tx, scheduler @@ -72,6 +82,8 @@ def get_rmsprop_with_linear_scheduler( eps: float = 1e-8, weight_decay: float = 1e-1, gradient_accumulation_steps: int = 1, + clip_grad: Optional[float] = None, + **kwargs, ) -> tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an RMSprop optimizer with a linear learning rate scheduler. @@ -87,17 +99,20 @@ def get_rmsprop_with_linear_scheduler( eps: A small value added to the denominator for numerical stability. weight_decay: The weight decay rate. gradient_accumulation_steps: The number of steps to accumulate gradients over. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: A tuple containing the optimizer and the learning rate scheduler. """ + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") scheduler = optax.linear_schedule( init_value=learning_rate_start, end_value=learning_rate_end, transition_steps=steps, ) - tx = optax.chain( + chain = [ optax.scale_by_rms( decay=decay, eps=eps, @@ -116,7 +131,10 @@ def get_rmsprop_with_linear_scheduler( optax.add_decayed_weights( weight_decay=weight_decay, ), - ) + ] + if clip_grad is not None: + chain.insert(0, optax.clip_by_global_norm(clip_grad)) + tx = optax.chain(*chain) if gradient_accumulation_steps > 1: tx = optax.MultiSteps(tx, gradient_accumulation_steps) return tx, scheduler @@ -134,6 +152,8 @@ def get_rmsprop_with_warmup_linear_scheduler( weight_decay: float = 1e-1, gradient_accumulation_steps: int = 1, warmup_steps: int = 100, + clip_grad: Optional[float] = None, + **kwargs, ) -> tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an RMSprop optimizer with a linear learning rate scheduler with warmup. @@ -150,10 +170,13 @@ def get_rmsprop_with_warmup_linear_scheduler( weight_decay: The weight decay rate. gradient_accumulation_steps: The number of steps to accumulate gradients over. warmup_steps: The number of warmup steps. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: A tuple containing the optimizer and the learning rate scheduler. """ + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") scheduler_warmup = optax.linear_schedule( init_value=5e-8, end_value=learning_rate_start, @@ -170,7 +193,7 @@ def get_rmsprop_with_warmup_linear_scheduler( boundaries=[warmup_steps], ) - tx = optax.chain( + chain = [ optax.scale_by_rms( decay=decay, eps=eps, @@ -189,7 +212,11 @@ def get_rmsprop_with_warmup_linear_scheduler( optax.add_decayed_weights( weight_decay=weight_decay, ), - ) + ] + + if clip_grad is not None: + chain.insert(0, optax.clip_by_global_norm(clip_grad)) + tx = optax.chain(*chain) if gradient_accumulation_steps > 1: tx = optax.MultiSteps(tx, gradient_accumulation_steps) return tx, scheduler_combined @@ -208,6 +235,8 @@ def get_rmsprop_with_warmup_cosine_scheduler( exponent: float = 1.0, gradient_accumulation_steps: int = 1, warmup_steps: int = 100, + clip_grad: Optional[float] = None, + **kwargs, ) -> tuple[optax.GradientTransformation, optax.Schedule]: """ Creates an RMSprop optimizer with a cosine learning rate scheduler with warmup. @@ -225,10 +254,13 @@ def get_rmsprop_with_warmup_cosine_scheduler( exponent: The exponent to use for the cosine decay. gradient_accumulation_steps: The number of steps to accumulate gradients over. warmup_steps: The number of warmup steps. + clip_grad (Optional[float]): If provided, gradients will be clipped to this maximum norm. Returns: A tuple containing the optimizer and the learning rate scheduler. """ + for kwarg in kwargs.keys(): + warnings.warn(f"Key {kwarg} is not used for optimizer.") scheduler = optax.warmup_cosine_decay_schedule( init_value=0.5e-7, peak_value=learning_rate, @@ -238,7 +270,7 @@ def get_rmsprop_with_warmup_cosine_scheduler( exponent=exponent, ) - tx = optax.chain( + chain = [ optax.scale_by_rms( decay=decay, eps=eps, @@ -257,7 +289,11 @@ def get_rmsprop_with_warmup_cosine_scheduler( optax.add_decayed_weights( weight_decay=weight_decay, ), - ) + ] + + if clip_grad is not None: + chain.insert(0, optax.clip_by_global_norm(clip_grad)) + tx = optax.chain(*chain) if gradient_accumulation_steps > 1: tx = optax.MultiSteps(tx, gradient_accumulation_steps) return tx, scheduler diff --git a/test/array4bit.py b/test/array4bit.py index b84e924..7d8fbae 100644 --- a/test/array4bit.py +++ b/test/array4bit.py @@ -90,10 +90,10 @@ def main(): init_x = jax.random.normal(rng.rng, (1, 64)) x = jax.random.normal(rng.rng, (1, 64)) params = model.init(rng.rng, init_x) - model: Callable = jax.jit(implicit_compact(model.apply)) + model_apply: Callable = jax.jit(implicit_compact(model.apply)) q_params = quantize_params(params) - q_out = float(model(q_params, x).reshape(-1)[0]) - out = float(model(params, x).reshape(-1)[0]) + q_out = float(model_apply(q_params, x).reshape(-1)[0]) + out = float(model_apply(params, x).reshape(-1)[0]) print(f"Original Model Output: {out:.3e}") print(f"Quantized Model Output: {q_out:.3e}")