Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Lora implementation for nn.Conv1d #2333

Merged
merged 6 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv1d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Conv3d):
Expand Down Expand Up @@ -1297,6 +1299,18 @@ def _get_dora_layer_class(self):
return DoraConv2dLayer


class Conv1d(_ConvNd):
# Lora implemented in a conv1d layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 3:
raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv1d
CCLDArjun marked this conversation as resolved.
Show resolved Hide resolved

def _get_dora_layer_class(self):
raise NotImplementedError


class Conv3d(_ConvNd):
# Lora implemented in a conv3d layer
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -1679,6 +1693,9 @@ def dispatch_default(
elif isinstance(target_base_layer, torch.nn.Conv3d):
kwargs.update(lora_config.loftq_config)
new_module = Conv3d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, nn.Conv1d):
kwargs.update(lora_config.loftq_config)
new_module = Conv1d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.MultiheadAttention):
kwargs.update(lora_config.loftq_config)
new_module = MultiheadAttention(target, adapter_name, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
# no module could be matched
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv1d`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
"`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention.`."
)

Expand Down
23 changes: 23 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
LoraConfig,
{"target_modules": ["emb", "conv1d"], "use_dora": True},
),
("Conv1d LoRA", "Conv1d", LoraConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
Expand Down Expand Up @@ -810,6 +811,25 @@ def get_output_embeddings(self):
return None


class ModelConv1D(nn.Module):
def __init__(self):
super().__init__()
self.conv1d = nn.Conv1d(1, 1, 2)
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(9, 2)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float().reshape(-1, 1, 10)
X = self.conv1d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X


class ModelConv2D(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -910,6 +930,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
if model_id == "EmbConv1D":
return ModelEmbConv1D().to(torch_dtype)

if model_id == "Conv1d":
return ModelConv1D().to(torch_dtype)

if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

Expand Down
Loading