eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a collection of tools for sharding, custom PyTrees, quantization, mixed precision training, and optimized operations, making it easier to build and scale models efficiently.
- Mixed Precision Training (`mpric`): Advanced mixed-precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling.
- Sharding Utilities (`escale`): Tools for efficient sharding and distributed computation in JAX.
- Custom PyTrees (`jaximus`): Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, updated from Equinox.
- Custom Calling (`callib`): A tool for custom function calls and direct integration with Triton kernels in JAX.
- Optimizer Factory: A flexible factory for creating and configuring optimizers such as AdamW, Adafactor, Lion, and RMSProp.
- Custom Operations and Kernels:
  - Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
  - 8-bit and NF4 quantization for efficient models.
  - Many others to be added.
- Quantization Support: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.
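To illustrate what "dynamic loss scaling" in the mixed-precision item means, here is a minimal sketch in plain JAX. It is not `mpric`'s API: `LossScaleState`, `init_loss_scale`, `update_loss_scale`, and `scaled_grads` are hypothetical names, and the defaults (initial scale 2^15, growth interval 2000, factor 2) are assumptions chosen for illustration. The loss is multiplied by a scale before the float16 backward pass, gradients are unscaled in float32, and the scale is halved on overflow or doubled after a run of clean steps.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LossScaleState(NamedTuple):
    scale: jax.Array       # current loss scale (float32 scalar)
    good_steps: jax.Array  # consecutive steps with finite gradients (int32 scalar)


def init_loss_scale(initial_scale=2.0**15):
    return LossScaleState(jnp.float32(initial_scale), jnp.int32(0))


def all_finite(tree):
    # True only if every leaf of the gradient pytree is free of inf/NaN.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]))


def update_loss_scale(state, grads_finite, growth_interval=2000, factor=2.0, min_scale=1.0):
    # Overflow: shrink the scale and reset the clean-step counter.
    shrunk = LossScaleState(jnp.maximum(state.scale / factor, min_scale), jnp.int32(0))
    # No overflow: count the step and grow the scale after `growth_interval` clean steps.
    good = state.good_steps + 1
    grow = good >= growth_interval
    grown = LossScaleState(
        jnp.where(grow, state.scale * factor, state.scale),
        jnp.where(grow, 0, good).astype(jnp.int32),
    )
    # Pick the branch per field without Python control flow, so this works under jit.
    return jax.tree_util.tree_map(lambda a, b: jnp.where(grads_finite, a, b), grown, shrunk)


def scaled_grads(loss_fn, params_f32, batch, state):
    # Forward/backward run in float16 while the master weights stay in float32.
    params_f16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params_f32)

    def scaled_loss(p):
        # Multiply the loss so small float16 gradients do not underflow to zero.
        return loss_fn(p, batch).astype(jnp.float32) * state.scale

    grads = jax.grad(scaled_loss)(params_f16)
    # Undo the scaling in float32; overflowed entries stay inf and are detected here.
    grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / state.scale, grads)
    finite = all_finite(grads)
    return grads, finite, update_loss_scale(state, finite)
```

A training step built on this would apply the optimizer update only when `finite` is true and otherwise keep the previous parameters, which is what makes the scale safe to grow aggressively.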
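The sharding item builds on JAX's own sharding primitives. The sketch below uses only core JAX (`Mesh`, `NamedSharding`, `PartitionSpec`), not `escale`'s helpers, whose API is not described here. The axis names `"dp"` and `"tp"` are illustrative assumptions. It places a batch of activations split along a data-parallel mesh axis and lets `jax.jit` compile a sharded matmul.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange every available device into a 2-D (data, model) mesh. The model axis
# has size 1 here, so the sketch also runs on a single CPU device.
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(n_dev, 1), axis_names=("dp", "tp"))

# The batch is sized to divide evenly across the "dp" axis.
x = jnp.arange(n_dev * 2 * 4, dtype=jnp.float32).reshape(n_dev * 2, 4)
w = jnp.ones((4, 2), dtype=jnp.float32)

# Rows of x are split over "dp"; columns of w are (trivially) split over "tp".
x_sharded = jax.device_put(x, NamedSharding(mesh, P("dp", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "tp")))


@jax.jit
def forward(x, w):
    # The compiler propagates the input shardings through the matmul.
    return x @ w


y = forward(x_sharded, w_sharded)
```

On a multi-device host, the same code splits the batch across devices without changing `forward`; only the mesh shape and partition specs decide the layout.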
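For the custom PyTrees item, the underlying mechanism is JAX's pytree registration: array fields become traced leaves, and everything else travels as static auxiliary data. The sketch below uses the standard `jax.tree_util.register_pytree_node_class` decorator rather than `jaximus` itself, and the `Linear` class is a hypothetical example layer.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical layer: arrays are pytree children, `name` is static aux data."""

    def __init__(self, weight, bias, name="linear"):
        self.weight = weight
        self.bias = bias
        self.name = name  # static: not traced, not differentiated

    def tree_flatten(self):
        # Children (leaves) first, then auxiliary data that must be hashable.
        return (self.weight, self.bias), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        weight, bias = children
        return cls(weight, bias, name=aux_data)

    def __call__(self, x):
        return x @ self.weight + self.bias


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
# Because Linear is a pytree, jax.grad returns gradients with the same structure.
grads = jax.grad(lambda m, x: m(x).sum())(layer, jnp.ones((4, 3)))
```

Since the layer is a pytree, transformations such as `jax.grad`, `jax.jit`, and `jax.tree_util.tree_map` accept and return `Linear` objects directly.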
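The Flash Attention kernels listed above are written in Triton and Pallas. The idea they share can be sketched in pure JAX: keys and values are processed block by block, and an online softmax keeps a running maximum `m` and normalizer `l` so the full attention matrix is never materialized. This is an illustrative single-head, non-causal reference under those assumptions, not eformer's kernel, and `blockwise_attention` and `reference_attention` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # Standard softmax(q k^T / sqrt(d)) v, materializing the full score matrix.
    s = (q @ k.T) / (q.shape[-1] ** 0.5)
    return jax.nn.softmax(s, axis=-1) @ v


def blockwise_attention(q, k, v, block_size=4):
    # q: (Lq, d); k, v: (Lk, d). Assumes Lk is divisible by block_size.
    lq, d = q.shape
    lk, dv = k.shape[0], v.shape[-1]
    scale = 1.0 / (d ** 0.5)
    kb = k.reshape(lk // block_size, block_size, d)
    vb = v.reshape(lk // block_size, block_size, dv)

    def step(carry, kv):
        m, l, acc = carry                      # running max, normalizer, output
        k_blk, v_blk = kv
        s = (q @ k_blk.T) * scale              # (Lq, block_size) scores
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])        # block probabilities, unnormalized
        correction = jnp.exp(m - m_new)        # rescale earlier partial sums
        l_new = l * correction + p.sum(axis=-1)
        acc_new = acc * correction[:, None] + p @ v_blk
        return (m_new, l_new, acc_new), None

    init = (
        jnp.full((lq,), -jnp.inf, dtype=q.dtype),
        jnp.zeros((lq,), dtype=q.dtype),
        jnp.zeros((lq, dv), dtype=q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (kb, vb))
    return acc / l[:, None]
```

The real kernels additionally tile the queries, fuse these steps into on-chip memory, and handle masking, which is where the speed and memory savings come from.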
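As a rough picture of what 8-bit weight quantization involves, the sketch below implements symmetric per-row absmax int8 quantization in plain JAX. It is a common generic scheme, not necessarily the one eformer uses, and NF4 (which maps weights onto a fixed 4-bit codebook) is not shown. `quantize_int8` and `dequantize_int8` are hypothetical names.

```python
import jax.numpy as jnp


def quantize_int8(w, axis=-1):
    # One scale per row (along `axis`) so the largest |value| maps to 127.
    absmax = jnp.max(jnp.abs(w), axis=axis, keepdims=True)
    # Avoid division by zero for all-zero rows.
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0)
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_int8(q, scale):
    # Recover an approximation of the original float weights.
    return q.astype(scale.dtype) * scale
```

Storing `q` (1 byte per weight) plus one float scale per row cuts weight memory roughly fourfold compared with float32, at the cost of a rounding error of at most half a scale step per element.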