Skip to content

Commit

Permalink
Add special case of Conv1d circular padding multiple channels
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Aug 3, 2020
1 parent d69c751 commit 61ed6e0
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
20 changes: 19 additions & 1 deletion tests/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,25 @@ def test_conv1d_circular_singlechannel(self):
weight = conv.weight
input = torch.randn(batch_size, 1, n)
out_torch = conv(input)
b = torch_butterfly.special.conv1d_circular_singlechannel(n, weight)
for separate_diagonal in [True, False]:
b = torch_butterfly.special.conv1d_circular_singlechannel(n, weight,
separate_diagonal)
out = b(input)
self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))

def test_conv1d_circular_multichannel(self):
batch_size = 10
in_channels = 3
out_channels = 4
for n in [13, 16]:
for kernel_size in [1, 3, 5, 7]:
conv = nn.Conv1d(in_channels, out_channels, kernel_size,
padding=(kernel_size - 1) // 2, padding_mode='circular',
bias=False)
weight = conv.weight
input = torch.randn(batch_size, in_channels, n)
out_torch = conv(input)
b = torch_butterfly.special.conv1d_circular_multichannel(n, weight)
out = b(input)
self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))

Expand Down
5 changes: 4 additions & 1 deletion torch_butterfly/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
from torch import nn

from torch_butterfly.complex_utils import complex_mul


class Diagonal(nn.Module):

Expand All @@ -29,4 +31,5 @@ def forward(self, input):
Return:
output: (batch, *, size)
"""
return input * self.diagonal
# return input * self.diagonal
return complex_mul(input, self.diagonal)
73 changes: 72 additions & 1 deletion torch_butterfly/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch_butterfly.butterfly import Butterfly
from torch_butterfly.permutation import FixedPermutation, bitreversal_permutation
from torch_butterfly.diagonal import Diagonal
from torch_butterfly.complex_utils import real2complex, Real2Complex, Complex2Real
from torch_butterfly.complex_utils import complex_mul, real2complex, Real2Complex, Complex2Real


def fft(n, normalized=False, br_first=True, with_br_perm=True):
Expand Down Expand Up @@ -185,3 +185,74 @@ def conv1d_circular_singlechannel(n, weight, separate_diagonal=True):
padding = (kernel_size - 1) // 2
col = F.pad(weight.flip(dims=(-1,)), (0, n - kernel_size)).roll(-padding, dims=-1)
return circulant(col.squeeze(1).squeeze(0))


def conv1d_circular_multichannel(n, weight):
""" Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
with multiple in/out channels, with circular padding.
The output of nn.Conv1d must have the same size as the input (i.e. kernel size must be 2k + 1,
and padding k for some integer k).
Parameters:
n: size of the input.
weight: torch.Tensor of size (out_channels, in_channels, kernel_size). Kernel_size must be
odd, and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
"""
assert weight.dim() == 3, 'Weight must have dimension 3'
kernel_size = weight.shape[-1]
assert kernel_size < n
assert kernel_size % 2 == 1, 'Kernel size must be odd'
out_channels, in_channels = weight.shape[:2]
padding = (kernel_size - 1) // 2
col = F.pad(weight.flip(dims=(-1,)), (0, n - kernel_size)).roll(-padding, dims=-1)
# From here we mimic the circulant construction, but the diagonal multiply is replaced with
# multiply and then sum across the in-channels.
complex = col.is_complex()
log_n = int(math.ceil(math.log2(n)))
# For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
# I've only figured out how to pad to size 1 << (log_n + 1).
# e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
b_fft = fft(n_extended, normalized=True, br_first=False, with_br_perm=False)
b_fft.in_size = n
b_ifft = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False)
b_ifft.out_size = n
if n < n_extended:
col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
col = torch.cat((col_0, col), dim=-1)
if not col.is_complex():
col = real2complex(col)
# This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
# circulant matrix.
col_f = torch.view_as_complex(torch.fft(torch.view_as_real(col),
signal_ndim=1, normalized=False))
br_perm = (bitreversal_permutation(n_extended, pytorch_format=True))
col_f = col_f[..., br_perm]
# We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
# This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
# multiply.

# We write this as an nn.Module just to use nn.Sequential
class DiagonalMultiplySum(nn.Module):
def __init__(self, diagonal_init):
"""
Parameters:
diagonal_init: (out_channels, in_channels, size)
"""
super().__init__()
self.diagonal = nn.Parameter(diagonal_init.detach().clone())

def forward(self, input):
"""
Parameters:
input: (batch, in_channels, size)
Return:
output: (batch, out_channels, size)
"""
# return input * self.diagonal
return complex_mul(input.unsqueeze(1), self.diagonal).sum(dim=2)

if not complex:
return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f), b_ifft,
Complex2Real())
else:
return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)

0 comments on commit 61ed6e0

Please sign in to comment.