Commit

code quality

Kye committed Nov 9, 2023
1 parent aa061f2 commit 90b2692
Showing 31 changed files with 395 additions and 337 deletions.
19 changes: 19 additions & 0 deletions code_quality.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+# Navigate to the directory containing the 'swarms_torch' folder
+# cd /path/to/your/code/directory
+
+# Run autopep8 with high aggressiveness (two --aggressive passes) and
+# in-place modification on all Python files under 'swarms_torch'.
+autopep8 --in-place --aggressive --aggressive --recursive --experimental swarms_torch/
+
+# Run black; it has no aggressiveness levels, but --experimental-string-processing
+# enables its experimental string splitting (now black's --preview mode).
+black --experimental-string-processing swarms_torch/
+
+# Run ruff on the 'swarms_torch' directory.
+# Add any additional flags if needed according to your version of ruff.
+ruff swarms_torch/
+
+# YAPF
+# yapf --recursive --in-place --verbose --style=google --parallel swarms_torch
Binary file modified docs/.DS_Store
Binary file not shown.
16 changes: 9 additions & 7 deletions playground/silu_visualization.py
@@ -2,23 +2,25 @@
 import numpy as np
 from mpl_toolkits.mplot3d import Axes3D
 
+
 # SiLU (Sigmoid-weighted Linear Unit) activation function
 def silu(x):
     return x * (1 / (1 + np.exp(-x)))
 
+
 # Generate inputs and calculate SiLU outputs
 input_values = np.linspace(-10, 10, 100)
 output_values = silu(input_values)
 
 # Create 3D plot
 fig = plt.figure()
-ax = fig.add_subplot(111, projection='3d')
+ax = fig.add_subplot(111, projection="3d")
 
 # Scatter plot of SiLU outputs
-ax.scatter(input_values, output_values, input_values, c=output_values, cmap='viridis')
-ax.set_xlabel('Input')
-ax.set_ylabel('Output')
-ax.set_zlabel('Input')
-ax.set_title('3D Visualization of SiLU Activation Function')
+ax.scatter(input_values, output_values, input_values, c=output_values, cmap="viridis")
+ax.set_xlabel("Input")
+ax.set_ylabel("Output")
+ax.set_zlabel("Input")
+ax.set_title("3D Visualization of SiLU Activation Function")
 
-plt.show()
\ No newline at end of file
+plt.show()
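
Note: the function being visualized is SiLU(x) = x * sigmoid(x). As a quick
spot-check of the definition above (a standalone sketch, not code from this
commit):

import numpy as np

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x * (1 / (1 + np.exp(-x)))

# silu(0) == 0; silu(x) -> x for large x; silu(x) -> 0 for very negative x.
print(silu(np.array([-10.0, 0.0, 1.0, 10.0])))
# approx [-4.54e-04, 0.0, 0.7311, 9.9995]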
2 changes: 1 addition & 1 deletion playground/swarmalator_example.py
@@ -10,4 +10,4 @@
 )
 
 # Call the visualization function
-visualize_swarmalators(xi)
\ No newline at end of file
+visualize_swarmalators(xi)
11 changes: 11 additions & 0 deletions requirements.txt
@@ -1,3 +1,14 @@
 torch
 einops
 pandas
+
+
+
+
+
+
+
+
+mkdocs
+mkdocs-material
+mkdocs-glightbox
12 changes: 5 additions & 7 deletions swarms_torch/ant_colony_swarm.py
@@ -68,15 +68,13 @@ def fitness(self, solution):
     def update_pheromones(self):
        """Update pheromone levels"""
         for i, solution in enumerate(self.solutions):
-            self.pheromones[i] = (1 - self.evaporation_rate) * self.pheromones[
-                i
-            ] + self.fitness(solution)
+            self.pheromones[i] = (1 - self.evaporation_rate
+                                 ) * self.pheromones[i] + self.fitness(solution)
 
     def choose_next_path(self):
         """Choose the next path based on the pheromone levels"""
         probabilities = (self.pheromones**self.alpha) * (
-            (1.0 / (1 + self.pheromones)) ** self.beta
-        )
+            (1.0 / (1 + self.pheromones))**self.beta)
 
         probabilities /= probabilities.sum()
 
@@ -90,8 +88,8 @@ def optimize(self):
             # This is a placeholder. Actual implementation will define how
             # ants traverse the search space.
             solution = torch.randint(
-                32, 127, (len(self.goal),), dtype=torch.float32
-            )  # Random characters.
+                32, 127, (len(self.goal),),
+                dtype=torch.float32)  # Random characters.
             self.solutions.append(solution)
             self.update_pheromones()
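
Note: the two rules reformatted above are the standard ant-colony updates:
evaporation plus deposit for pheromones, and a normalized weighting for path
choice. A standalone sketch with assumed toy values (not code from this
commit):

import torch

evaporation_rate, alpha, beta = 0.1, 1.0, 1.0  # assumed toy hyperparameters
pheromones = torch.tensor([1.0, 2.0, 4.0])
fitness_scores = torch.tensor([0.5, 0.2, 0.9])

# Evaporate, then deposit: tau_i <- (1 - rho) * tau_i + fitness(solution_i)
pheromones = (1 - evaporation_rate) * pheromones + fitness_scores

# Path-choice weights: tau**alpha * (1 / (1 + tau))**beta, then normalize
probabilities = (pheromones**alpha) * ((1.0 / (1 + pheromones))**beta)
probabilities /= probabilities.sum()
print(probabilities)  # sums to 1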
103 changes: 53 additions & 50 deletions swarms_torch/autoregressive.py
@@ -25,6 +25,7 @@ def cast_tuple(t, length=1):


 def eval_decorator(fn):
+
     def inner(self, *args, **kwargs):
         was_training = self.training
         self.eval()
@@ -47,7 +48,8 @@ def align_right(t, lens, pad_id=0):
     pad_lens = seq_len - lens
     max_pad_len = pad_lens.amax()
 
-    batch_arange = torch.arange(batch, device=device, dtype=torch.long)[..., None]
+    batch_arange = torch.arange(batch, device=device, dtype=torch.long)[...,
+                                                                        None]
     prompt_len_arange = torch.arange(seq_len, device=device, dtype=torch.long)
 
     t = F.pad(t, (max_pad_len, 0), value=0)
@@ -65,7 +67,8 @@ def top_p(logits, thres=0.9):
     cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
 
     sorted_indices_to_remove = cum_probs > thres
-    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value=False)
+    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1),
+                                     value=False)
 
     sorted_logits[sorted_indices_to_remove] = float("-inf")
     return sorted_logits.scatter(1, sorted_indices, sorted_logits)
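
Note: top_p above is nucleus sampling: sort the logits, keep the smallest
prefix whose cumulative probability exceeds thres, and mask the rest to -inf.
A standalone sketch with assumed toy logits (not code from this commit):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # assumed toy logits
thres = 0.9

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

remove = cum_probs > thres
remove = F.pad(remove, (1, -1), value=False)  # shift so the crossing token stays
sorted_logits[remove] = float("-inf")

# Scatter the filtered values back to their original positions.
print(sorted_logits.scatter(1, sorted_indices, sorted_logits))
# Only the lowest-probability tail becomes -inf here.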
@@ -108,18 +111,21 @@ def contrastive_decode_fn(expert_logits, amateur_logits, alpha=0.1, beta=0.5):
     cutoff = log(alpha) + expert_logits.amax(dim=-1, keepdim=True)
     diffs = (1 + beta) * expert_logits - beta * amateur_logits
     contrastive_decode_logits = diffs.masked_fill(
-        expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max
-    )
+        expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
     return contrastive_decode_logits
 
 
 # autoregressive wrapper class
 
 
 class AutoregressiveWrapper(Module):
-    def __init__(
-        self, net, ignore_index=-100, pad_value=0, mask_prob=0.0, add_attn_z_loss=False
-    ):
+
+    def __init__(self,
+                 net,
+                 ignore_index=-100,
+                 pad_value=0,
+                 mask_prob=0.0,
+                 add_attn_z_loss=False):
         super().__init__()
         self.pad_value = pad_value
         self.ignore_index = ignore_index
@@ -138,21 +144,20 @@ def __init__(

     @torch.no_grad()
     @eval_decorator
-    def generate(
-        self,
-        prompts,
-        seq_len,
-        eos_token=None,
-        temperature=1.0,
-        prompt_lens: Optional[Tensor] = None,
-        filter_logits_fn: Callable = top_k,
-        restrict_to_max_seq_len=True,
-        amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
-        filter_kwargs: dict = dict(),
-        contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(beta=0.5, alpha=0.1),
-        cache_kv=True,
-        **kwargs
-    ):
+    def generate(self,
+                 prompts,
+                 seq_len,
+                 eos_token=None,
+                 temperature=1.0,
+                 prompt_lens: Optional[Tensor] = None,
+                 filter_logits_fn: Callable = top_k,
+                 restrict_to_max_seq_len=True,
+                 amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
+                 filter_kwargs: dict = dict(),
+                 contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
+                     beta=0.5, alpha=0.1),
+                 cache_kv=True,
+                 **kwargs):
         max_seq_len, device = self.max_seq_len, prompts.device
 
         prompts, ps = pack([prompts], "* n")
@@ -200,16 +205,15 @@ def generate(
             if exists(cache):
                 for inter in cache.attn_intermediates:
                     inter.cached_kv = [
-                        t[..., -(max_seq_len - 1) :, :] for t in inter.cached_kv
+                        t[..., -(max_seq_len - 1):, :]
+                        for t in inter.cached_kv
                     ]
 
-            logits, new_cache = self.net(
-                x,
-                return_intermediates=True,
-                cache=cache,
-                seq_start_pos=seq_start_pos,
-                **kwargs
-            )
+            logits, new_cache = self.net(x,
+                                         return_intermediates=True,
+                                         cache=cache,
+                                         seq_start_pos=seq_start_pos,
+                                         **kwargs)
 
             if cache_kv and self.net.can_cache_kv:
                 cache = new_cache
@@ -221,28 +225,27 @@ def generate(

             if exists(amateur_model):
                 for i, (
-                    amateur,
-                    amateur_cache,
-                    amateur_contrastive_decode_kwargs,
+                        amateur,
+                        amateur_cache,
+                        amateur_contrastive_decode_kwargs,
                 ) in enumerate(
-                    zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
-                ):
+                        zip(amateur_model, amateur_caches,
+                            contrastive_decode_kwargs)):
                     amateur_logits, next_amateur_cache = amateur(
                         x,
                         return_intermediates=True,
                         cache=amateur_cache,
                         seq_start_pos=seq_start_pos,
-                        **kwargs
-                    )
+                        **kwargs)
 
                     amateur_logits = amateur_logits[:, -1]
 
-                    assert (
-                        amateur_logits.shape == logits.shape
-                    ), "logits dimension are not the same between amateur and expert model"
+                    assert amateur_logits.shape == logits.shape, (
+                        "logits dimension are not the same between amateur and expert"
+                        " model")
                     logits = contrastive_decode_fn(
-                        logits, amateur_logits, **amateur_contrastive_decode_kwargs
-                    )
+                        logits, amateur_logits,
+                        **amateur_contrastive_decode_kwargs)
 
                     if cache_kv and amateur.can_cache_kv:
                         amateur_caches[i] = next_amateur_cache
@@ -286,20 +289,20 @@ def forward(self, x, **kwargs):
         if self.mask_prob > 0.0:
             rand = torch.randn(inp.shape, device=x.device)
             rand[:, 0] = -torch.finfo(
-                rand.dtype
-            ).max  # first token should not be masked out
+                rand.dtype).max  # first token should not be masked out
             num_mask = min(int(seq * self.mask_prob), seq - 1)
             indices = rand.topk(num_mask, dim=-1).indices
             mask = ~torch.zeros_like(inp).scatter(1, indices, 1.0).bool()
             kwargs.update(self_attn_kv_mask=mask)
 
-        logits, cache = self.net(
-            inp, return_intermediates=True, return_attn_z_loss=add_attn_z_loss, **kwargs
-        )
+        logits, cache = self.net(inp,
+                                 return_intermediates=True,
+                                 return_attn_z_loss=add_attn_z_loss,
+                                 **kwargs)
 
-        loss = F.cross_entropy(
-            rearrange(logits, "b n c -> b c n"), target, ignore_index=ignore_index
-        )
+        loss = F.cross_entropy(rearrange(logits, "b n c -> b c n"),
+                               target,
+                               ignore_index=ignore_index)
 
         if add_attn_z_loss:
             loss = loss + cache.attn_z_loss
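
Note: contrastive_decode_fn above implements contrastive decoding: tokens whose
expert logit falls more than log(alpha) below the maximum are masked out, and
the survivors are scored by (1 + beta) * expert - beta * amateur. A worked
standalone sketch with assumed toy logits (not code from this commit):

import torch
from math import log

alpha, beta = 0.1, 0.5  # the defaults in the signature above
expert = torch.tensor([[2.0, 1.0, -3.0]])   # assumed expert logits
amateur = torch.tensor([[1.5, 1.4, -1.0]])  # assumed amateur logits

cutoff = log(alpha) + expert.amax(dim=-1, keepdim=True)  # 2.0 + log(0.1), about -0.30
diffs = (1 + beta) * expert - beta * amateur
out = diffs.masked_fill(expert < cutoff, -torch.finfo(expert.dtype).max)
print(out)  # third token masked; the first two keep scores 2.25 and 0.80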
10 changes: 5 additions & 5 deletions swarms_torch/cellular_transformer.py
@@ -3,6 +3,7 @@


 class TransformerCell(nn.Module):
+
     def __init__(
         self,
         input_dim,
@@ -11,9 +12,9 @@ def __init__(
         neighborhood_size=3,
     ):
         super(TransformerCell, self).__init__()
-        self.transformer = nn.Transformer(
-            input_dim, nhead=nhead, num_encoder_layers=num_layers
-        )
+        self.transformer = nn.Transformer(input_dim,
+                                          nhead=nhead,
+                                          num_encoder_layers=num_layers)
         self.neighborhood_size = neighborhood_size
 
     def forward(self, x, neigbors):
Expand Down Expand Up @@ -56,8 +57,7 @@ class CellularSwarm(nn.Module):
     def __init__(self, cell_count, input_dim, nhead, time_steps=4):
         super(CellularSwarm, self).__init__()
         self.cells = nn.ModuleList(
-            [TransformerCell(input_dim, nhead) for _ in range(cell_count)]
-        )
+            [TransformerCell(input_dim, nhead) for _ in range(cell_count)])
         self.time_steps = time_steps
 
     def forward(self, x):
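
Note: nn.Transformer's first positional argument is d_model, which must be
divisible by nhead. A standalone sketch of the call pattern TransformerCell
uses above, with assumed dimensions (not code from this commit):

import torch
import torch.nn as nn

d_model, nhead, num_layers = 16, 4, 2  # assumed dimensions
transformer = nn.Transformer(d_model, nhead=nhead, num_encoder_layers=num_layers)

# nn.Transformer defaults to (seq, batch, d_model) inputs (batch_first=False).
src = torch.randn(10, 2, d_model)
tgt = torch.randn(10, 2, d_model)
print(transformer(src, tgt).shape)  # torch.Size([10, 2, 16])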
21 changes: 12 additions & 9 deletions swarms_torch/fish_school.py
@@ -72,9 +72,10 @@ def __init__(
         alpha=0.1,
     ):
         super().__init__()
-        self.model = Transformer(
-            d_model=dim, nhead=heads, num_encoder_layers=depth, num_decoder_layers=depth
-        )
+        self.model = Transformer(d_model=dim,
+                                 nhead=heads,
+                                 num_encoder_layers=depth,
+                                 num_decoder_layers=depth)
         self.optimizer = Adam(self.parameters())
         self.scheduler = ReduceLROnPlateau(self.optimizer, "min")

@@ -96,13 +97,15 @@ def train(self, src, tgt, labels):
         outputs = self.model(src, tgt)
 
         # cross entropy loss
-        loss = CrossEntropyLoss()(outputs.view(-1, outputs.size(-1)), labels.view(-1))
+        loss = CrossEntropyLoss()(outputs.view(-1, outputs.size(-1)),
+                                  labels.view(-1))
 
         # complexity regularization by adding the sum of the squares of the
         # weights
         if self.complexity_regularization:
             # complexity regularization
-            loss += self.alpha * sum(p.pow(2.0).sum() for p in self.model.parameters())
+            loss += self.alpha * sum(
+                p.pow(2.0).sum() for p in self.model.parameters())
 
         # backpropagation
         loss.backward()
@@ -211,7 +214,8 @@ def forward(self, src, tgt, labels):
         # with higher food
         if self.complex_school:
             for fish in self.fish:
-                neighbor = self.fish[torch.randint(0, len(self.fish), (1,)).item()]
+                neighbor = self.fish[torch.randint(0, len(self.fish),
+                                                   (1,)).item()]
                 if neighbor.food > fish.food:
                     fish.model.load_state_dict(neighbor.model.state_dict())

@@ -234,9 +238,8 @@ def predict(self, src, tgt):
         averages outputs of the top performing models
         """
-        top_fish = sorted(self.fish, key=lambda f: f.food, reverse=True)[
-            : self.num_top_fish
-        ]
+        top_fish = sorted(self.fish, key=lambda f: f.food,
+                          reverse=True)[:self.num_top_fish]
 
         self.model.eval()

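
Note: the complexity-regularization term reformatted above is a plain L2
penalty: alpha times the sum of squared weights over all parameters. A
standalone sketch with an assumed toy model (not code from this commit):

import torch
import torch.nn as nn

alpha = 0.1  # matches the default in the constructor above
model = nn.Linear(4, 2)  # assumed toy model

# L2 penalty: alpha * sum of p**2 over every parameter tensor
l2_penalty = alpha * sum(p.pow(2.0).sum() for p in model.parameters())
print(l2_penalty)  # scalar tensor, added onto the cross-entropy loss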