Preparing For version Update 0.0.70
erfanzar committed Jul 17, 2024
1 parent 063299d commit c0662f6
Showing 12 changed files with 417 additions and 305 deletions.
345 changes: 109 additions & 236 deletions .vscode/PythonImportHelper-v2-Completion.json

Large diffs are not rendered by default.

90 changes: 65 additions & 25 deletions README.md
@@ -1,36 +1,76 @@
# FJFormer

Embark on a journey of unparalleled computational prowess with FJFormer - an arsenal of custom JAX/Flax functions and utilities that elevate your AI endeavors to new heights!
[![PyPI version](https://badge.fury.io/py/fjformer.svg)](https://badge.fury.io/py/fjformer)
[![Documentation Status](https://readthedocs.org/projects/fjformer/badge/?version=latest)](https://fjformer.readthedocs.io/en/latest/?badge=latest)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

## Overview
FJFormer is a powerful and flexible JAX-based package designed to accelerate and simplify machine learning and deep learning workflows. It provides a comprehensive suite of tools and utilities for efficient model development, training, and deployment.

FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It includes checkpoint savers, partitioning tools, and other helpful functions.
The goal of FJFormer is to make your life easier when working with Flax and JAX, whether you are training a new model, fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks.
## Features

- Pallas kernels for GPU and TPU
- Bit computations for 8-, 6-, and 4-bit Flax models
- Built-in helper and loss functions
- Distributed and sharded model loaders and checkpoint savers
- Monitoring utilities for TPU/GPU/CPU memory footprint
- Easy-to-use optimizers with built-in schedulers
- Partitioning utilities
- LoRA
### 1. JAX Sharding Utils
Leverage the power of distributed computing and model parallelism with our advanced JAX sharding utilities. These tools enable efficient splitting and management of large models across multiple devices, enhancing performance and enabling the training of larger models.
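
A minimal sketch of the kind of partitioning these utilities build on, using plain `jax.sharding` APIs; FJFormer's own helper names are not shown in this commit, so none are assumed here:

```python
# Sharding a large array across all local devices with plain jax.sharding.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange all local devices along a single "dp" (data-parallel) mesh axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Shard the leading dimension of the array across the mesh.
sharding = NamedSharding(mesh, PartitionSpec("dp", None))
x = jax.device_put(jnp.zeros((8192, 1024)), sharding)

# jit-compiled computations consume the distributed array directly.
y = jax.jit(lambda a: (a @ a.T).sum())(x)
print(y, x.sharding)
```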

Many of these features are fully documented, so FJFormer has plenty to offer on its own - it is not just a computation backend for [EasyDel](https://github.com/erfanzar/EasyDel).
### 2. Custom Pallas / Triton Operation Kernels
Boost your model's performance with our optimized kernels for specific operations. These custom-built kernels, implemented using Pallas and Triton, provide significant speedups for common bottleneck operations in deep learning models.
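
To illustrate the style of kernel involved, here is a toy element-wise Pallas kernel written with `jax.experimental.pallas`; FJFormer's production kernels are considerably more involved:

```python
# A toy element-wise addition kernel; real kernels target heavier fused ops.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Inside the kernel, refs are read and written with NumPy-style indexing.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # interpret=True,  # uncomment to run/debug on CPU without a GPU/TPU
    )(x, y)

print(add(jnp.arange(8.0), jnp.ones(8)))  # [1. 2. 3. 4. 5. 6. 7. 8.]
```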

Check out the documentation [here](https://fjformer.readthedocs.io/en/latest/).
### 3. Pre-built Optimizers
Jump-start your training with our collection of ready-to-use, efficiently implemented optimization algorithms (a usage sketch follows the list):
- **AdamW**: An Adam variant with decoupled weight decay.
- **Adafactor**: Memory-efficient adaptive optimization algorithm.
- **Lion**: Recently proposed optimizer combining the benefits of momentum and adaptive methods.
- **RMSprop**: Adaptive learning rate optimization algorithm.
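
The sketch below shows the general optax-style training step these optimizers plug into; it uses plain `optax.adamw` rather than assuming FJFormer-specific factory names:

```python
# Generic optax-style update step; FJFormer optimizer factories are not assumed here.
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
tx = optax.adamw(learning_rate=3e-4, weight_decay=0.01)
opt_state = tx.init(params)

def loss_fn(p, x):
    return jnp.mean((x @ p["w"] + p["b"]) ** 2)

x = jnp.ones((2, 4))
grads = jax.grad(loss_fn)(params, x)
updates, opt_state = tx.update(grads, opt_state, params)  # params needed for weight decay
params = optax.apply_updates(params, updates)
```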

## Contributing
### 4. Utility Functions
A rich set of utility functions to streamline your workflow (example after the list), including:
- Various loss functions (e.g., cross-entropy)
- Metrics calculation
- Data preprocessing tools
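
As a small example of the loss-function category, here is a plain JAX cross-entropy implementation; FJFormer's own helpers may differ in name and signature:

```python
# Cross-entropy over integer labels in plain JAX.
import jax
import jax.numpy as jnp

def cross_entropy(logits, labels):
    # logits: [batch, num_classes]; labels: [batch] integer class ids.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return jnp.mean(nll)

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])
labels = jnp.array([0, 1])
print(cross_entropy(logits, labels))
```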

FJFormer is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or
just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or
open an issue.
### 5. ImplicitArray
Our innovative ImplicitArray class provides a powerful abstraction for representing and manipulating large arrays without instantiating them (a conceptual sketch follows the list). Benefits include:
- Lazy evaluation for memory efficiency
- Optimized array operations in JAX
- Seamless integration with other FJFormer components
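
A conceptual sketch of the idea only, deliberately not FJFormer's exact API: the array is described by cheap metadata and only turned into a dense buffer on demand:

```python
# Conceptual illustration of materialize-on-demand, not FJFormer's ImplicitArray API.
import jax.numpy as jnp

class LazyZeros:
    """Stands in for a large all-zeros array without allocating it."""

    def __init__(self, shape, dtype=jnp.float32):
        self.shape = shape
        self.dtype = dtype

    def materialize(self):
        # The full buffer is only allocated here, when actually needed.
        return jnp.zeros(self.shape, self.dtype)

# FJFormer's ImplicitArray subclasses (e.g. Array4Bit, Array8Bit) follow the same
# pattern, plus registered JAX primitive handlers so most ops never need the dense form.
lazy = LazyZeros((4096, 4096))
dense = lazy.materialize()
```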

Thank you for using FJFormer, and happy training!
### 6. Custom Dtypes

- Implement 4-bit quantization (NF4) effortlessly using our Array4Bit class, built on top of ImplicitArray. Reduce model size and increase inference speed without significant loss in accuracy.

- Similar to Array4Bit, our Array8Bit implementation offers 8-bit quantization via ImplicitArray, providing a balance between model compression and precision.
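
A hedged usage sketch based on the `Array8Bit.quantize` signature visible further down in this commit; the import path below is an assumption and may differ in the released package:

```python
import jax.numpy as jnp
from fjformer.custom_array.array8bit import Array8Bit  # assumed import path

w = jnp.ones((1024, 1024), dtype=jnp.float32)

w_q = Array8Bit.quantize(w, axis=-1)  # 8-bit representation built on ImplicitArray
w_back = w_q.materialize()            # dense array reconstructed on demand
```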

### 7. LoRA (Low-Rank Adaptation)
Efficiently fine-tune large language models with our LoRA implementation, leveraging ImplicitArray for optimal performance and memory usage.
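
The low-rank update itself is simple; the sketch below shows the bare idea in plain JAX (FJFormer's LoRA module automates this across whole models, so names and shapes here are illustrative only):

```python
# Bare LoRA idea: effective weight is W + scaling * (B @ A), never formed explicitly.
import jax
import jax.numpy as jnp

d, k, r = 1024, 1024, 8                      # layer dims and LoRA rank
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))

W = jnp.zeros((d, k))                        # frozen pretrained weight (dummy here)
A = 0.01 * jax.random.normal(key_a, (r, k))  # trainable low-rank factor
B = jnp.zeros((d, r))                        # trainable, zero-init so the delta starts at 0
scaling = 1.0

def forward(x):
    return x @ W.T + scaling * ((x @ A.T) @ B.T)

y = forward(jax.random.normal(key_x, (2, k)))
print(y.shape)  # (2, 1024)
```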

### 8. JAX and Array Manipulation
A comprehensive set of tools and utilities for efficient array operations and manipulations in JAX, designed to complement and extend JAX's native capabilities.

### 9. Checkpoint Managers
Robust utilities for managing model checkpoints (a sketch follows the list), including:
- Efficient saving and loading of model states
- Version control for checkpoints
- Integration with distributed training workflows
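
As a generic illustration of what checkpointing a parameter pytree involves (not FJFormer's checkpoint-manager API, which is not shown here), a `flax.serialization` round trip looks like this:

```python
# Generic parameter-pytree save/restore; FJFormer's managers add sharding-aware
# saving, versioning, and distributed-training integration on top of this idea.
import jax.numpy as jnp
from flax import serialization

params = {"dense": {"kernel": jnp.ones((8, 8)), "bias": jnp.zeros((8,))}}

# Save the pytree to a msgpack file.
with open("ckpt.msgpack", "wb") as f:
    f.write(serialization.to_bytes(params))

# Restore into a template pytree with matching structure.
with open("ckpt.msgpack", "rb") as f:
    restored = serialization.from_bytes(params, f.read())
```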

## Installation

You can install FJFormer using pip:

```bash
pip install fjformer
```

For the latest development version, you can install directly from GitHub:

```bash
pip install git+https://github.com/erfanzar/FJFormer.git
```

## Documentation

For detailed documentation, including API references, please visit:

[https://fjformer.readthedocs.io](https://fjformer.readthedocs.io)

## License

FJFormer is released under the Apache License 2.0. See the [LICENSE](LICENSE) file for more details.
43 changes: 43 additions & 0 deletions docs/contributing.rst
@@ -0,0 +1,43 @@
Contributing to FJFormer
========================
Thank you for considering contributing to FJFormer! We welcome your input. To ensure a smooth collaboration, please review and adhere to the following guidelines.


How to Contribute
-----------------
To contribute to FJFormer, follow these steps:

1. Fork the repository.
2. Create a new branch for your feature or bug fix.
3. Make your changes and commit them with clear and descriptive messages.
4. Push your changes to your branch in your forked repository.
5. Submit a pull request to the main FJFormer repository, detailing the changes you've made and the problems they solve.


Code of Conduct
---------------
Please adhere to the `Apache Code of Conduct <https://www.apache.org/foundation/policies/conduct.html>`_ in all interactions related to FJFormer.

Reporting Bugs
--------------
If you encounter a bug, please open an issue on the FJFormer repository, providing a clear and detailed description of the issue, including steps to reproduce it.

Suggesting Enhancements
-----------------------
If you have ideas for enhancements, feel free to open an issue on the FJFormer repository. Provide a clear and detailed description of your proposed enhancement.

Development Setup
-----------------
To set up FJFormer for development, follow the instructions in the README.md file.

Pull Request Guidelines
-----------------------
When submitting a pull request, please ensure the following:

- Your code follows the project's coding standards.
- Your commits are accompanied by clear and descriptive messages.
- Your pull request addresses a single issue or feature.

License
-------
By contributing to FJFormer, you agree that your contributions will be licensed under the Apache License, Version 2.0.

Thank you for your interest in contributing to FJFormer! We appreciate your support.
68 changes: 50 additions & 18 deletions docs/index.rst
@@ -1,26 +1,52 @@
FJFormer 🔮
============
Embark on a journey of unparalleled computational prowess with FJFormer - an arsenal of custom JAX/Flax functions and utilities that elevate your AI endeavors to new heights!
FJFormer is a powerful and flexible JAX-based package designed to accelerate and simplify machine learning and deep learning workflows. It provides a comprehensive suite of tools and utilities for efficient model development, training, and deployment.

Overview
Features
----------
FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It includes checkpoint savers, partitioning tools, and other helpful functions.
The goal of FJFormer is to make your life easier when working with Flax and JAX, whether you are training a new model, fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks.

- Pallas kernels for GPU and TPU
- Bit computations for 8-, 6-, and 4-bit Flax models
- Built-in helper and loss functions
- Distributed and sharded model loaders and checkpoint savers
- Monitoring utilities for TPU/GPU/CPU memory footprint
- Easy-to-use optimizers with built-in schedulers
- Partitioning utilities
- LoRA

1. **JAX Sharding Utils**

   Leverage the power of distributed computing and model parallelism with our advanced JAX sharding utilities. These tools enable efficient splitting and management of large models across multiple devices, enhancing performance and enabling the training of larger models.

2. **Custom Pallas / Triton Operation Kernels**

   Boost your model's performance with our optimized kernels for specific operations. These custom-built kernels, implemented using Pallas and Triton, provide significant speedups for common bottleneck operations in deep learning models.

3. **Pre-built Optimizers**

   Jump-start your training with our collection of ready-to-use, efficiently implemented optimization algorithms:

   - **AdamW**: An Adam variant with decoupled weight decay.
   - **Adafactor**: Memory-efficient adaptive optimization algorithm.
   - **Lion**: Recently proposed optimizer combining the benefits of momentum and adaptive methods.
   - **RMSprop**: Adaptive learning rate optimization algorithm.

4. **Utility Functions**

   A rich set of utility functions to streamline your workflow, including:

   - Various loss functions (e.g., cross-entropy)
   - Metrics calculation
   - Data preprocessing tools

5. **ImplicitArray**

   Our innovative ImplicitArray class provides a powerful abstraction for representing and manipulating large arrays without instantiating them. Benefits include:

   - Lazy evaluation for memory efficiency
   - Optimized array operations in JAX
   - Seamless integration with other FJFormer components

6. **Custom Dtypes**

   - Implement 4-bit quantization (NF4) effortlessly using our Array4Bit class, built on top of ImplicitArray. Reduce model size and increase inference speed without significant loss in accuracy.
   - Similar to Array4Bit, our Array8Bit implementation offers 8-bit quantization via ImplicitArray, providing a balance between model compression and precision.

7. **LoRA (Low-Rank Adaptation)**

   Efficiently fine-tune large language models with our LoRA implementation, leveraging ImplicitArray for optimal performance and memory usage.

8. **JAX and Array Manipulation**

   A comprehensive set of tools and utilities for efficient array operations and manipulations in JAX, designed to complement and extend JAX's native capabilities.

9. **Checkpoint Managers**

   Robust utilities for managing model checkpoints, including:

   - Efficient saving and loading of model states
   - Version control for checkpoints
   - Integration with distributed training workflows

.. _FJFormer:

Zare Chavoshi, Erfan. "FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX."
@@ -34,3 +60,9 @@ Zare Chavoshi, Erfan. "FJFormer is a collection of functions and utilities that
api_docs/APIs


.. toctree::
:hidden:
:maxdepth: 1
:caption: Getting Started

contributing
13 changes: 10 additions & 3 deletions src/fjformer/core/implicit_array.py
@@ -16,8 +16,8 @@
from dataclasses import dataclass, field, fields, is_dataclass
from functools import partial, wraps
from itertools import chain, count
from typing import ClassVar, Optional, Callable, Any, Tuple
from fjformer.core.errors import UninitializedAval
from typing import Any, Callable, ClassVar, Optional, Tuple

import jax
import jax.extend.linear_util as lu
import jax.interpreters.partial_eval as pe
@@ -27,9 +27,16 @@
from jax.tree_util import register_pytree_with_keys_class
from plum import Dispatcher, Function

from fjformer.core.errors import UninitializedAval

_dispatch = Dispatcher()
_primitive_ids = count()

warnings.filterwarnings(
"ignore", message="Could not resolve the type hint of `~B`", module="plum.type"
)
warnings.filterwarnings(
"ignore", message="Could not resolve the type hint of `~A`", module="plum.type"
)

class ArrayValue(ABC):
pass
16 changes: 12 additions & 4 deletions src/fjformer/custom_array/array4bit.py
@@ -467,12 +467,20 @@ def handle_transpose(
"""
original_quantized = False
if isinstance(operand, Array4Bit):
operand = operand.materialize()
array = operand.materialize()
original_quantized = True
operand = lax.transpose(operand, *args, **kwargs)
else:
array = operand
array = lax.transpose(array, *args, **kwargs)
if original_quantized:
operand = Array4Bit.quantize(operand, dtype=operand.dtype)
return operand
array = Array4Bit.quantize(
array=array,
block_size=operand.block_size,
contraction_axis=operand.contraction_axis,
dtype=operand.dtype,
factors=operand.factors,
)
return array


@core.primitive_handler("conv_general_dilated")
5 changes: 4 additions & 1 deletion src/fjformer/custom_array/array8bit.py
@@ -69,7 +69,10 @@ def materialize(self) -> Array:

@classmethod
def quantize(
cls, array: Array, axis: int = -1, dtype: Optional[jnp.dtype] = None
cls,
array: Array,
axis: int = -1,
dtype: Optional[jnp.dtype] = None,
) -> "Array8Bit":
"""
Quantize a JAX array to 8-bit representation.
