Skip to content

Commit

Permalink
add tests
Browse files · Browse the repository at this point in the history
  • Loading branch information
CCLDArjun committed Jan 22, 2025
1 parent ba03b7a commit 7218296
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
LoraConfig,
{"target_modules": ["emb", "conv1d"], "use_dora": True},
),
("Conv1d LoRA", "Conv1d", LoraConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
Expand Down Expand Up @@ -810,6 +811,25 @@ def get_output_embeddings(self):
return None


class ModelConv1D(nn.Module):
    """Small toy model exposing an ``nn.Conv1d`` layer ("conv1d") for LoRA targeting.

    Mirrors the other toy models in this file (e.g. ModelConv2D): a conv layer
    followed by ReLU, Flatten, a Linear head ("lin0"), and LogSoftmax, so both
    "conv1d" and "lin0" can be used as target_modules.

    NOTE: the original used transformers' ``Conv1D(5, 10)`` (a transposed
    Linear, last dim 10 -> 5), which made Flatten emit 25 features while
    ``lin0`` expected 10 -- forward always raised a shape mismatch. The test
    case is named "Conv1d" (torch's 1D convolution), so use ``nn.Conv1d`` with
    consistent shapes instead.
    """

    def __init__(self):
        super().__init__()
        # (batch, 1, 10) -> (batch, 1, 9): 1 in-channel, 1 out-channel, kernel size 2
        self.conv1d = nn.Conv1d(1, 1, 2)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        # Flatten of (batch, 1, 9) yields 9 features, matching this head.
        self.lin0 = nn.Linear(9, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        # Reshape any input whose size is a multiple of 10 into (batch, 1, 10)
        # so the conv sees a single channel of length 10.
        X = X.float().reshape(-1, 1, 10)
        X = self.conv1d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


class ModelConv2D(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -910,6 +930,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
if model_id == "EmbConv1D":
return ModelEmbConv1D().to(torch_dtype)

if model_id == "Conv1d":
return ModelConv1D().to(torch_dtype)

if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

Expand Down

0 comments on commit 7218296

Please sign in to comment.