Commit b62e95c
Kye committed Nov 29, 2023
1 parent: 6f029ba
Showing 5 changed files with 212 additions and 9 deletions.
@@ -0,0 +1,98 @@ (new file: tests for ConvolutionLanguageBlock)
from unittest.mock import Mock

import pytest
import torch
from torch import nn

from zeta.nn.modules.lang_conv_module import ConvolutionLanguageBlock


# 1. Basic Tests
def test_convolution_language_block_creation():
    block = ConvolutionLanguageBlock(256, 512, 3, 1)
    assert isinstance(block, ConvolutionLanguageBlock)


def test_forward_pass():
    block = ConvolutionLanguageBlock(256, 512, 3, 1)
    x = torch.randn(1, 256, 1024)
    output = block(x)
    assert output.shape == torch.Size([1, 512, 1024])


# 2. Utilize Fixtures
@pytest.fixture
def sample_block():
    return ConvolutionLanguageBlock(128, 256, 3, 1)


def test_fixture_usage(sample_block):
    x = torch.randn(1, 128, 1024)
    output = sample_block(x)
    assert output.shape == torch.Size([1, 256, 1024])


# 3. Parameterized Testing
@pytest.mark.parametrize(
    (
        "in_channels, out_channels, kernel_size, padding, depth, stride,"
        " activation, batchnorm, dilation, dropout"
    ),
    [
        (128, 256, 3, 1, 2, 1, "relu", True, 1, 0.1),
        # padding must equal dilation (kernel_size 3, stride 1) to keep the
        # sequence length, so the shape assertion below holds
        (256, 512, 3, 2, 3, 1, "gelu", False, 2, 0.2),
        # Add more parameter combinations as needed
    ],
)
def test_parameterized_block(
    in_channels,
    out_channels,
    kernel_size,
    padding,
    depth,
    stride,
    activation,
    batchnorm,
    dilation,
    dropout,
):
    block = ConvolutionLanguageBlock(
        in_channels,
        out_channels,
        kernel_size,
        padding,
        depth,
        stride,
        activation,
        batchnorm,
        dilation,
        dropout,
    )
    x = torch.randn(1, in_channels, 1024)
    output = block(x)
    assert output.shape == torch.Size([1, out_channels, 1024])


# 4. Mocking
def test_with_mocked_convolution_layer():
    block = ConvolutionLanguageBlock(128, 256, 3, 1)
    mock_convolution = Mock(spec=nn.Conv1d)
    # Return a real tensor so the activation/dropout layers after the mock
    # still receive valid input
    mock_convolution.return_value = torch.randn(1, 256, 1024)
    # nn.Module.__setattr__ rejects non-Module children, so swap the layer
    # in the Sequential's module dict directly
    block.conv_layers._modules["0"] = mock_convolution
    x = torch.randn(1, 128, 1024)
    output = block(x)
    assert mock_convolution.called
    assert output.shape == torch.Size([1, 256, 1024])


# 5. Exception Testing
def test_invalid_activation_raises_error():
    with pytest.raises(ValueError):
        ConvolutionLanguageBlock(
            128, 256, 3, 1, activation="invalid_activation"
        )


# 6. Test Coverage (requires pytest-cov): run from the command line rather
# than from inside a test, e.g.:
#   pytest --cov=zeta.nn.modules.lang_conv_module


# Add more tests as needed...
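A note on the shape assertions above: with stride 1 and kernel size 3, a Conv1d preserves sequence length only when padding equals dilation, which is why the second parameter set pairs padding=2 with dilation=2. A minimal sketch of the standard PyTorch Conv1d output-length formula (the helper name conv1d_out_len is my own):

import torch
from torch import nn

def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    # Standard PyTorch Conv1d output-length formula
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

assert conv1d_out_len(1024, 3, padding=1, dilation=1) == 1024  # length preserved
assert conv1d_out_len(1024, 3, padding=1, dilation=2) == 1022  # shrinks by 2 per layer
assert conv1d_out_len(1024, 3, padding=2, dilation=2) == 1024  # padding = dilation restores it

# Cross-check against an actual layer
layer = nn.Conv1d(8, 8, 3, padding=2, dilation=2)
assert layer(torch.randn(1, 8, 1024)).shape[-1] == 1024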
@@ -0,0 +1,104 @@ (new file: zeta/nn/modules/lang_conv_module.py)
import torch
from torch import nn


class ConvolutionLanguageBlock(nn.Module):
    """
    Convolutional block for language modeling.
    --------------------------------------------
    A convolutional block that consists of multiple 1D convolutional layers,
    optional batch normalization, dropout, and a flexible choice of activation
    functions. With stride 1 and padding equal to the dilation (for a kernel
    size of 3), the block preserves the input's sequence length, making it
    suitable for tasks that require consistent input and output dimensions.

    Parameters:
    - in_channels (int): Number of channels in the input tensor.
    - out_channels (int): Number of channels produced by the convolution.
    - kernel_size (int): Size of the convolving kernel.
    - padding (int): Zero-padding added to both sides of the input.
    - depth (int, optional): Number of stacked convolutional layers. Default: 1
    - stride (int, optional): Stride of the convolution. Default: 1
    - activation (str, optional): Activation function, 'relu' or 'gelu'. Default: 'gelu'
    - batchnorm (bool, optional): If True, includes batch normalization. Default: False
    - dilation (int, optional): Spacing between kernel elements. Default: 1
    - dropout (float, optional): Dropout rate. Default: 0.1

    Examples:
    >>> import torch
    >>> from zeta.nn.modules.lang_conv_module import ConvolutionLanguageBlock
    >>> x = torch.randn(1, 512, 1024)
    >>> block = ConvolutionLanguageBlock(512, 512, 3, 1, 1, 1)
    >>> out = block(x)
    >>> out.shape
    torch.Size([1, 512, 1024])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        depth=1,
        stride=1,
        activation="gelu",
        batchnorm=False,
        dilation=1,
        dropout=0.1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.depth = depth
        self.stride = stride
        self.activation = activation
        self.batchnorm = batchnorm
        self.dilation = dilation

        layers = []
        for _ in range(depth):
            layers.append(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                )
            )
            if batchnorm:
                layers.append(nn.BatchNorm1d(out_channels))
            if activation == "relu":
                layers.append(nn.ReLU())
            elif activation == "gelu":
                layers.append(nn.GELU())
            else:
                # Unknown activations fail fast instead of being silently skipped
                raise ValueError(f"Unsupported activation: {activation!r}")
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_channels = out_channels  # For stacking layers

        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass with an optional residual connection.

        Args:
            x (torch.Tensor): Input of shape (batch, in_channels, length).

        Returns:
            torch.Tensor: Output of shape (batch, out_channels, length).
        """
        # A residual connection is only valid when the output has as many
        # channels as the input (the default padding keeps the length)
        residual = x if x.size(1) == self.out_channels else None

        # Apply convolutional layers
        x = self.conv_layers(x)

        # Add the residual connection when shapes allow it
        if residual is not None:
            x = x + residual

        # Return output
        return x
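For reference, a minimal usage sketch of the residual behavior (assuming the import path above; shapes follow the class docstring):

import torch
from zeta.nn.modules.lang_conv_module import ConvolutionLanguageBlock

# Channels match: the input is added back onto the conv stack's output
same = ConvolutionLanguageBlock(512, 512, 3, 1)
assert same(torch.randn(1, 512, 1024)).shape == torch.Size([1, 512, 1024])

# Channels differ: the residual is skipped and only the conv stack runs
proj = ConvolutionLanguageBlock(256, 512, 3, 1)
assert proj(torch.randn(1, 256, 1024)).shape == torch.Size([1, 512, 1024])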