Commit d5f4e6d

ENH Improve HRA speed and docs (#2160)

DaShenZi721 authored Oct 21, 2024
1 parent e8259ff commit d5f4e6d
Showing 4 changed files with 56 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -116,6 +116,8 @@
title: FourierFT
- local: package_reference/vblora
title: VB-LoRA
- local: package_reference/hra
title: HRA

title: Adapters
- sections:
13 changes: 13 additions & 0 deletions docs/source/conceptual_guides/adapter.md
@@ -105,3 +105,16 @@ A set of learnable adaption prompts are prefixed to the input instruction tok
<small><a href="https://hf.co/papers/2303.16199">LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention</a></small>

To avoid adding noise to the tokens, the adapter uses zero-initialized attention. On top of this, the adapter adds a learnable gating factor (initialized with zeros) to progressively add information to the model during training. This prevents overwhelming the model's pretrained knowledge with the newly learned instructions.
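For intuition, here is a minimal sketch of such a zero-initialized gate in PyTorch; the class name `ZeroInitGate` and the `tanh` squashing are illustrative rather than the exact LLaMA-Adapter implementation:

```python
import torch
import torch.nn as nn

class ZeroInitGate(nn.Module):
    """Blend an adapter branch into the base output via a gate that starts at zero."""

    def __init__(self):
        super().__init__()
        # Zero initialization: the adapter contributes nothing at the first step,
        # so the pretrained behavior is untouched at the start of training.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, base_out: torch.Tensor, adapter_out: torch.Tensor) -> torch.Tensor:
        # As the gate is learned, adapter information is added progressively.
        return base_out + torch.tanh(self.gate) * adapter_out
```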

## Householder Reflection Adaptation (HRA)

[HRA](https://huggingface.co/papers/2405.17484) provides a new perspective connecting LoRA to OFT, which means it can harness the advantages of both strategies, reducing parameters and computation costs while penalizing the loss of pre-training knowledge.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/hra.png"/>
</div>
<small><a href="https://huggingface.co/papers/2405.17484">Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation</a></small>

HRA constructs a chain of `r` trainable Householder reflections (HRs). Because each Householder reflection matrix is orthogonal and the product of orthogonal matrices is itself orthogonal, HRA satisfies the theoretical guarantees of Orthogonal Finetuning (OFT). At the same time, by rewriting its formula, HRA can also be viewed as a low-rank fine-tuning adapter.
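As an illustration, a minimal sketch of applying such a chain to a frozen weight matrix, assuming the `r` reflection vectors are stacked as columns of a matrix `U` (the function name and the normalization step are illustrative):

```python
import torch

def apply_hra_chain(weight: torch.Tensor, U: torch.Tensor) -> torch.Tensor:
    """Multiply `weight` by a chain of Householder reflections built from the columns of `U`."""
    U = U / U.norm(dim=0, keepdim=True)  # each column defines a unit-norm reflection plane
    for i in range(U.shape[1]):
        u = U[:, i].view(-1, 1)
        # Right-multiplying by (I - 2 u u^T) is an orthogonal Householder reflection;
        # expanding the product avoids materializing the identity matrix.
        weight = weight - 2 * weight @ u @ u.t()
    return weight
```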

The higher `r`, the more trainable parameters there are, resulting in larger model capacity and better performance. Moreover, due to the chain structure, the orthogonality of the HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between model capacity and regularity, an orthogonality regularizer on the HR planes is added to the loss function. Its weight \\(\lambda\\) controls the strength of the regularizer.
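A minimal sketch of such a regularizer, assuming it penalizes the deviation of the Gram matrix of the HR planes from the identity (the exact form used in the HRA implementation may differ):

```python
import torch

def hr_orthogonality_penalty(U: torch.Tensor, lam: float) -> torch.Tensor:
    """Encourage the Householder reflection planes (columns of `U`) to be mutually orthogonal."""
    U = U / U.norm(dim=0, keepdim=True)
    gram = U.t() @ U  # (r, r); equals the identity iff the planes are orthonormal
    eye = torch.eye(U.shape[1], device=U.device, dtype=U.dtype)
    return lam * (gram - eye).pow(2).sum()  # add this term to the task loss
```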
32 changes: 32 additions & 0 deletions docs/source/package_reference/hra.md
@@ -0,0 +1,32 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation (HRA)

[HRA](https://huggingface.co/papers/2405.17484) is a simple but effective adapter-based fine-tuning method that leverages Householder reflections. It harnesses the advantages of both low-rank and orthogonal adaptation, reducing parameters and computation costs while penalizing the loss of pre-training knowledge. It consistently achieves better performance with fewer trainable parameters and outperforms state-of-the-art adapters across different models, including large language models (LLMs) and conditional image generators.


The abstract from the paper is:

> While following different technical routes, both low-rank and orthogonal adaptation techniques can efficiently adapt large-scale pre-training models in specific tasks or domains based on a small piece of trainable parameters. In this study, we bridge the gap between these two techniques, proposing a simple but effective adaptation method based on Householder reflections. Given a pre-trained model, our method fine-tunes its layers by multiplying each frozen weight matrix with an orthogonal matrix constructed by a chain of learnable Householder reflections (HRs). This HR-based orthogonal fine-tuning is equivalent to an adaptive low-rank adaptation. Moreover, we show that the orthogonality of the reflection planes corresponding to the HRs impacts the model capacity and regularity. The analysis motivates us to regularize the orthogonality of the HRs, leading to different implementations of the proposed Householder reflection adaptation (HRA) method. Compared with state-of-the-art methods, HRA achieves superior performance with fewer learnable parameters when adapting large language models and conditional image generators. The code is available at [peft](https://github.com/huggingface/peft/tree/main/src/peft/tuners/hra) and [HRA](https://github.com/DaShenZi721/HRA).

## HRAConfig

[[autodoc]] tuners.hra.config.HRAConfig

## HRAModel

[[autodoc]] tuners.hra.model.HRAModel
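A minimal usage sketch (the base checkpoint and the `target_modules` below are illustrative placeholders, not requirements):

```python
from transformers import AutoModelForCausalLM
from peft import HRAConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = HRAConfig(
    r=8,                                  # number of Householder reflections in the chain
    target_modules=["q_proj", "v_proj"],  # which sub-modules receive HRA adapters
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```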
13 changes: 9 additions & 4 deletions src/peft/tuners/hra/layer.py
@@ -220,7 +220,7 @@ def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Te

        for i in indices:
            ui = opt_u[:, i].view(-1, 1)
-           weight = weight @ (torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * ui @ ui.t())
+           weight = weight - 2 * weight @ ui @ ui.t()

        return weight

@@ -384,7 +384,7 @@ def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Te

        for i in indices:
            ui = opt_u[:, i].view(-1, 1)
-           weight = weight @ (torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * ui @ ui.t())
+           weight = weight - 2 * weight @ ui @ ui.t()

        return weight
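The rewritten update above is algebraically identical to multiplying by the explicit Householder matrix, since `W @ (I - 2 u u^T) == W - 2 * W @ u @ u^T`; it just avoids materializing the identity. A quick sanity check with illustrative sizes:

```python
import torch

# Sanity check (illustrative sizes): the expanded form matches the explicit
# Householder product without building the d x d identity matrix.
d = 64
W = torch.randn(d, d)
u = torch.randn(d, 1)
u = u / u.norm()  # Householder vectors are unit-norm

old = W @ (torch.eye(d) - 2 * u @ u.t())
new = W - 2 * W @ u @ u.t()
assert torch.allclose(old, new, atol=1e-6)
```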

@@ -399,7 +399,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        result = self.base_layer(x, *args, **kwargs)
    else:
        new_weight = torch.eye(
-           self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device
+           self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
+           device=x.device,
+           dtype=previous_dtype,
        )
        for active_adapter in self.active_adapters:
            if active_adapter not in self.hra_u.keys():
@@ -416,7 +418,10 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
            )
            new_weight = torch.mm(orig_weight, new_weight)
            new_weight = new_weight.view(
-               self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
+               self.out_features,
+               self.in_features,
+               self.base_layer.kernel_size[0],
+               self.base_layer.kernel_size[0],
            )

            result = F.conv2d(
