Skip to content

Commit

Permalink
[FEAT][FractorialNet][FractorialBlock]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 25, 2023
1 parent e3e5185 commit 7a5975d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 5 deletions.
1 change: 1 addition & 0 deletions zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from zeta.nn.modules.log_ff import LogFF, compute_entropy_safe
from zeta.nn.modules.polymorphic_neuron import PolymorphicNeuronLayer
from zeta.nn.modules.flexible_mlp import CustomMLP
from zeta.nn.modules.fractoril_net import

__all__ = [
"CNNNew",
Expand Down
85 changes: 80 additions & 5 deletions zeta/nn/modules/fractorial_net.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FractorialBlock(nn.Module):
def __init__(self, in_channels, out_channels, depth: int = 3):
super(FractorialBlock, self).__init__()
class FractalBlock(nn.Module):
def __init__(self, in_channels, out_channels, depth=3):
"""
Initialize a Fractal Block.
:param in_channels: Number of input channels.
:param out_channels: Number of output channels.
:param depth: Depth of the fractal block.
"""
super(FractalBlock, self).__init__()
self.depth = depth

# Base case for recursion
if depth == 1:
self.block = nn.Conv2d(
in_channels, out_channels, kernel_size=3, padding=1
)
else:
# Recursive case: create smaller fractal blocks
self.block1 = FractalBlock(in_channels, out_channels, depth - 1)
self.block2 = FractalBlock(in_channels, out_channels, depth - 1)

def forward(self, x):
"""
Forward pass of the fractal block.
:param x: Input tensor.
:return: Output tensor.
"""
if self.depth == 1:
return self.block(x)
else:
# Recursively compute the outputs of the sub-blocks
out1 = self.block1(x)
out2 = self.block2(x)

# Combine the outputs of the sub-blocks
return out1 + out2


class FractalNetwork(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, block_depth):
"""
Initialize the Fractal Network.
:param in_channels: Number of input channels.
:param out_channels: Number of output channels.
:param num_blocks: Number of fractal blocks in the network.
:param block_depth: Depth of each fractal block.
"""
super(FractalNetwork, self).__init__()
self.blocks = nn.ModuleList(
[
FractalBlock(
in_channels if i == 0 else out_channels,
out_channels,
block_depth,
)
for i in range(num_blocks)
]
)
self.final_layer = nn.Conv2d(out_channels, out_channels, kernel_size=1)

def forward(self, x):
"""
Forward pass of the fractal network.
:param x: Input tensor.
:return: Output tensor.
"""
for block in self.blocks:
x = block(x)
return self.final_layer(x)


# # Example usage
# fractal_net = FractalNetwork(in_channels=3, out_channels=16, num_blocks=4, block_depth=3)

# # Example input
# input_tensor = torch.randn(1, 3, 64, 64)

# # Forward pass
# output = fractal_net(input_tensor)
# print(output)

0 comments on commit 7a5975d

Please sign in to comment.