
Commit

[CODE QUALITY]
Kye committed Dec 19, 2023
1 parent aedbcd4 commit 5e1132a
Showing 30 changed files with 562 additions and 356 deletions.
1 change: 0 additions & 1 deletion playground/silu_visualization.py
@@ -1,6 +1,5 @@
 import matplotlib.pyplot as plt
 import numpy as np
-from mpl_toolkits.mplot3d import Axes3D


 # SiLU (Sigmoid-weighted Linear Unit) activation function
2 changes: 0 additions & 2 deletions swarms_torch/__init__.py
@@ -7,8 +7,6 @@
 from swarms_torch.spiral_optimization import SPO
 from swarms_torch.multi_swarm_pso import MultiSwarmPSO
 from swarms_torch.transformer_pso import Particle, TransformerParticleSwarmOptimization
-from swarms_torch.swarmalators.swarmalator_visualize import visualize_swarmalators
-from swarms_torch.swarmalators.swarmalator_base import simulate_swarmalators

 __all__ = [
     "ParticleSwarmOptimization",
12 changes: 7 additions & 5 deletions swarms_torch/ant_colony_swarm.py
@@ -68,13 +68,15 @@ def fitness(self, solution):
     def update_pheromones(self):
         """Update pheromone levels"""
         for i, solution in enumerate(self.solutions):
-            self.pheromones[i] = (1 - self.evaporation_rate
-                                  ) * self.pheromones[i] + self.fitness(solution)
+            self.pheromones[i] = (1 - self.evaporation_rate) * self.pheromones[
+                i
+            ] + self.fitness(solution)

     def choose_next_path(self):
         """Choose the next path based on the pheromone levels"""
         probabilities = (self.pheromones**self.alpha) * (
-            (1.0 / (1 + self.pheromones))**self.beta)
+            (1.0 / (1 + self.pheromones)) ** self.beta
+        )

         probabilities /= probabilities.sum()

@@ -88,8 +90,8 @@ def optimize(self):
             # This is a placeholder. Actual implementation will define how
             # ants traverse the search space.
             solution = torch.randint(
-                32, 127, (len(self.goal),),
-                dtype=torch.float32)  # Random characters.
+                32, 127, (len(self.goal),), dtype=torch.float32
+            )  # Random characters.
             self.solutions.append(solution)
             self.update_pheromones()

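The reformatted update above is the standard evaporate-and-reinforce rule, tau_i <- (1 - rho) * tau_i + fitness(solution_i), with path probabilities p_i proportional to tau_i**alpha * (1 / (1 + tau_i))**beta. A minimal standalone sketch of just that math (toy values and shapes, not the repository's API):

    import torch

    # Illustrative hyperparameters and a toy pheromone vector.
    evaporation_rate, alpha, beta = 0.1, 1.0, 2.0
    pheromones = torch.ones(5)                      # one level per candidate solution
    fitnesses = torch.tensor([0.2, 0.9, 0.5, 0.1, 0.7])

    # Evaporate, then reinforce by fitness: tau_i <- (1 - rho) * tau_i + f_i
    pheromones = (1 - evaporation_rate) * pheromones + fitnesses

    # Path probabilities: tau**alpha weighted against a crowding penalty.
    probabilities = (pheromones**alpha) * ((1.0 / (1 + pheromones)) ** beta)
    probabilities /= probabilities.sum()
    next_path = torch.multinomial(probabilities, 1).item()
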
100 changes: 49 additions & 51 deletions swarms_torch/autoregressive.py
@@ -25,7 +25,6 @@ def cast_tuple(t, length=1):


 def eval_decorator(fn):
-
     def inner(self, *args, **kwargs):
         was_training = self.training
         self.eval()
@@ -48,8 +47,7 @@ def align_right(t, lens, pad_id=0):
     pad_lens = seq_len - lens
     max_pad_len = pad_lens.amax()

-    batch_arange = torch.arange(batch, device=device, dtype=torch.long)[...,
-                                                                        None]
+    batch_arange = torch.arange(batch, device=device, dtype=torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device=device, dtype=torch.long)

     t = F.pad(t, (max_pad_len, 0), value=0)
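
For readers skimming the diff: align_right shifts each row's valid tokens to the right edge so variable-length prompts share a common end position. A self-contained sketch of the same indexing trick (toy tensors; the offset step is inferred from context, since this hunk shows only part of the function):

    import torch
    import torch.nn.functional as F

    t = torch.tensor([[5, 6, 0, 0], [7, 8, 9, 0]])   # rows padded on the right
    lens = torch.tensor([2, 3])                       # valid tokens per row

    seq_len = t.shape[-1]
    pad_lens = seq_len - lens
    max_pad_len = pad_lens.amax()

    batch_arange = torch.arange(t.shape[0])[..., None]
    prompt_len_arange = torch.arange(seq_len)

    t = F.pad(t, (max_pad_len, 0), value=0)           # grow room on the left
    offset = max_pad_len - pad_lens                   # inferred step, not in the hunk
    aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
    # aligned: tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
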
@@ -67,8 +65,7 @@ def top_p(logits, thres=0.9):
     cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

     sorted_indices_to_remove = cum_probs > thres
-    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1),
-                                     value=False)
+    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value=False)

     sorted_logits[sorted_indices_to_remove] = float("-inf")
     return sorted_logits.scatter(1, sorted_indices, sorted_logits)
@@ -111,21 +108,18 @@ def contrastive_decode_fn(expert_logits, amateur_logits, alpha=0.1, beta=0.5):
     cutoff = log(alpha) + expert_logits.amax(dim=-1, keepdim=True)
     diffs = (1 + beta) * expert_logits - beta * amateur_logits
     contrastive_decode_logits = diffs.masked_fill(
-        expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
+        expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max
+    )
     return contrastive_decode_logits


 # autoregressive wrapper class


 class AutoregressiveWrapper(Module):
-
-    def __init__(self,
-                 net,
-                 ignore_index=-100,
-                 pad_value=0,
-                 mask_prob=0.0,
-                 add_attn_z_loss=False):
+    def __init__(
+        self, net, ignore_index=-100, pad_value=0, mask_prob=0.0, add_attn_z_loss=False
+    ):
         super().__init__()
         self.pad_value = pad_value
         self.ignore_index = ignore_index
@@ -144,20 +138,21 @@ def __init__(self,

     @torch.no_grad()
     @eval_decorator
-    def generate(self,
-                 prompts,
-                 seq_len,
-                 eos_token=None,
-                 temperature=1.0,
-                 prompt_lens: Optional[Tensor] = None,
-                 filter_logits_fn: Callable = top_k,
-                 restrict_to_max_seq_len=True,
-                 amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
-                 filter_kwargs: dict = dict(),
-                 contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
-                     beta=0.5, alpha=0.1),
-                 cache_kv=True,
-                 **kwargs):
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        eos_token=None,
+        temperature=1.0,
+        prompt_lens: Optional[Tensor] = None,
+        filter_logits_fn: Callable = top_k,
+        restrict_to_max_seq_len=True,
+        amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
+        filter_kwargs: dict = dict(),
+        contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(beta=0.5, alpha=0.1),
+        cache_kv=True,
+        **kwargs
+    ):
         max_seq_len, device = self.max_seq_len, prompts.device

         prompts, ps = pack([prompts], "* n")
@@ -205,15 +200,16 @@ def generate(self,
                 if exists(cache):
                     for inter in cache.attn_intermediates:
                         inter.cached_kv = [
-                            t[..., -(max_seq_len - 1):, :]
-                            for t in inter.cached_kv
+                            t[..., -(max_seq_len - 1) :, :] for t in inter.cached_kv
                         ]

-            logits, new_cache = self.net(x,
-                                         return_intermediates=True,
-                                         cache=cache,
-                                         seq_start_pos=seq_start_pos,
-                                         **kwargs)
+            logits, new_cache = self.net(
+                x,
+                return_intermediates=True,
+                cache=cache,
+                seq_start_pos=seq_start_pos,
+                **kwargs
+            )

             if cache_kv and self.net.can_cache_kv:
                 cache = new_cache
@@ -225,27 +221,29 @@

            if exists(amateur_model):
                for i, (
-                       amateur,
-                       amateur_cache,
-                       amateur_contrastive_decode_kwargs,
+                    amateur,
+                    amateur_cache,
+                    amateur_contrastive_decode_kwargs,
                ) in enumerate(
-                       zip(amateur_model, amateur_caches,
-                           contrastive_decode_kwargs)):
+                    zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
+                ):
                    amateur_logits, next_amateur_cache = amateur(
                        x,
                        return_intermediates=True,
                        cache=amateur_cache,
                        seq_start_pos=seq_start_pos,
-                       **kwargs)
+                        **kwargs
+                    )

                    amateur_logits = amateur_logits[:, -1]

                    assert amateur_logits.shape == logits.shape, (
                        "logits dimension are not the same between amateur and expert"
-                       " model")
+                        " model"
+                    )
                    logits = contrastive_decode_fn(
-                       logits, amateur_logits,
-                       **amateur_contrastive_decode_kwargs)
+                        logits, amateur_logits, **amateur_contrastive_decode_kwargs
+                    )

                    if cache_kv and amateur.can_cache_kv:
                        amateur_caches[i] = next_amateur_cache
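
The contrastive step applied here follows the contrastive_decode_fn defined earlier in this file: amateur-model logits are subtracted from expert logits, and tokens the expert itself finds implausible are masked out. A toy sketch of that rule (illustrative tensors only):

    import torch

    expert = torch.randn(1, 10)    # expert-model logits over a toy vocab
    amateur = torch.randn(1, 10)   # amateur-model logits over the same vocab
    alpha, beta = 0.1, 0.5

    # Keep tokens the expert finds plausible: logit >= log(alpha) + max logit.
    cutoff = torch.log(torch.tensor(alpha)) + expert.amax(dim=-1, keepdim=True)
    # Score survivors by (1 + beta) * expert - beta * amateur.
    diffs = (1 + beta) * expert - beta * amateur
    cd_logits = diffs.masked_fill(expert < cutoff, -torch.finfo(expert.dtype).max)
    next_token = cd_logits.argmax(dim=-1)
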
@@ -289,20 +287,20 @@ def forward(self, x, **kwargs):
         if self.mask_prob > 0.0:
             rand = torch.randn(inp.shape, device=x.device)
             rand[:, 0] = -torch.finfo(
-                rand.dtype).max  # first token should not be masked out
+                rand.dtype
+            ).max  # first token should not be masked out
             num_mask = min(int(seq * self.mask_prob), seq - 1)
             indices = rand.topk(num_mask, dim=-1).indices
             mask = ~torch.zeros_like(inp).scatter(1, indices, 1.0).bool()
             kwargs.update(self_attn_kv_mask=mask)

-        logits, cache = self.net(inp,
-                                 return_intermediates=True,
-                                 return_attn_z_loss=add_attn_z_loss,
-                                 **kwargs)
+        logits, cache = self.net(
+            inp, return_intermediates=True, return_attn_z_loss=add_attn_z_loss, **kwargs
+        )

-        loss = F.cross_entropy(rearrange(logits, "b n c -> b c n"),
-                               target,
-                               ignore_index=ignore_index)
+        loss = F.cross_entropy(
+            rearrange(logits, "b n c -> b c n"), target, ignore_index=ignore_index
+        )

         if add_attn_z_loss:
             loss = loss + cache.attn_z_loss
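End to end, the wrapper adds next-token training and cached sampling on top of a decoder. A hedged usage sketch, assuming an x-transformers-style decoder that exposes max_seq_len and can_cache_kv as this wrapper expects (hyperparameters are illustrative):

    import torch
    from x_transformers import Decoder, TransformerWrapper

    from swarms_torch.autoregressive import AutoregressiveWrapper

    # Assumed-compatible decoder; the wrapper reads net.max_seq_len and
    # net.can_cache_kv, which x-transformers' TransformerWrapper provides.
    net = TransformerWrapper(
        num_tokens=256,
        max_seq_len=128,
        attn_layers=Decoder(dim=64, depth=2, heads=4),
    )
    model = AutoregressiveWrapper(net, pad_value=0)

    seq = torch.randint(0, 256, (1, 128))   # toy token ids
    loss = model(seq)                       # shifted next-token cross-entropy
    loss.backward()

    prompt = torch.randint(0, 256, (1, 16))
    sampled = model.generate(prompt, seq_len=64, temperature=0.8)
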
10 changes: 5 additions & 5 deletions swarms_torch/cellular_transformer.py
@@ -3,7 +3,6 @@


 class TransformerCell(nn.Module):
-
     def __init__(
         self,
         input_dim,
@@ -12,9 +11,9 @@ def __init__(
         neighborhood_size=3,
     ):
         super(TransformerCell, self).__init__()
-        self.transformer = nn.Transformer(input_dim,
-                                          nhead=nhead,
-                                          num_encoder_layers=num_layers)
+        self.transformer = nn.Transformer(
+            input_dim, nhead=nhead, num_encoder_layers=num_layers
+        )
         self.neighborhood_size = neighborhood_size

     def forward(self, x, neigbors):
@@ -57,7 +56,8 @@ class CellularSwarm(nn.Module):
     def __init__(self, cell_count, input_dim, nhead, time_steps=4):
         super(CellularSwarm, self).__init__()
         self.cells = nn.ModuleList(
-            [TransformerCell(input_dim, nhead) for _ in range(cell_count)])
+            [TransformerCell(input_dim, nhead) for _ in range(cell_count)]
+        )
         self.time_steps = time_steps

     def forward(self, x):
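A quick usage sketch for the reformatted module. The input layout is an assumption: nn.Transformer without batch_first expects (seq_len, batch, d_model), and the dimensions below are illustrative:

    import torch

    from swarms_torch.cellular_transformer import CellularSwarm

    model = CellularSwarm(cell_count=5, input_dim=512, nhead=8, time_steps=4)

    # Assumed seq-first layout: (sequence length, batch, embedding dim).
    x = torch.randn(10, 1, 512)
    out = model(x)
    print(out.shape)
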
21 changes: 9 additions & 12 deletions swarms_torch/fish_school.py
@@ -72,10 +72,9 @@ def __init__(
         alpha=0.1,
     ):
         super().__init__()
-        self.model = Transformer(d_model=dim,
-                                 nhead=heads,
-                                 num_encoder_layers=depth,
-                                 num_decoder_layers=depth)
+        self.model = Transformer(
+            d_model=dim, nhead=heads, num_encoder_layers=depth, num_decoder_layers=depth
+        )
         self.optimizer = Adam(self.parameters())
         self.scheduler = ReduceLROnPlateau(self.optimizer, "min")

@@ -97,15 +96,13 @@ def train(self, src, tgt, labels):
         outputs = self.model(src, tgt)

         # cross entropy loss
-        loss = CrossEntropyLoss()(outputs.view(-1, outputs.size(-1)),
-                                  labels.view(-1))
+        loss = CrossEntropyLoss()(outputs.view(-1, outputs.size(-1)), labels.view(-1))

         # complexity regularization by adding the sum of the squares of the
         # weights
         if self.complexity_regularization:
             # complexity regularization
-            loss += self.alpha * sum(
-                p.pow(2.0).sum() for p in self.model.parameters())
+            loss += self.alpha * sum(p.pow(2.0).sum() for p in self.model.parameters())

         # backpropagation
         loss.backward()
@@ -214,8 +211,7 @@ def forward(self, src, tgt, labels):
         # with higher food
         if self.complex_school:
             for fish in self.fish:
-                neighbor = self.fish[torch.randint(0, len(self.fish),
-                                                   (1,)).item()]
+                neighbor = self.fish[torch.randint(0, len(self.fish), (1,)).item()]
                 if neighbor.food > fish.food:
                     fish.model.load_state_dict(neighbor.model.state_dict())

@@ -238,8 +234,9 @@ def predict(self, src, tgt):
         averages outputs of the top peforming models
         """
-        top_fish = sorted(self.fish, key=lambda f: f.food,
-                          reverse=True)[:self.num_top_fish]
+        top_fish = sorted(self.fish, key=lambda f: f.food, reverse=True)[
+            : self.num_top_fish
+        ]

         self.model.eval()

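The loop reformatted above implements the school's imitation rule: each fish samples one random peer and adopts the peer's weights when the peer has accumulated more food; predict() then keeps only the top performers. A standalone sketch with toy stand-ins for the Fish objects (names are illustrative):

    import torch

    class Agent:
        def __init__(self):
            self.model = torch.nn.Linear(4, 4)
            self.food = torch.rand(1).item()    # fitness proxy

    agents = [Agent() for _ in range(10)]

    # Imitation: compare against one random peer, copy the better weights.
    for agent in agents:
        peer = agents[torch.randint(0, len(agents), (1,)).item()]
        if peer.food > agent.food:
            agent.model.load_state_dict(peer.model.state_dict())

    # Prediction side: keep only the top performers, as in predict().
    top = sorted(agents, key=lambda a: a.food, reverse=True)[:3]
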
23 changes: 10 additions & 13 deletions swarms_torch/graph_cellular_automa.py
@@ -3,7 +3,6 @@


 class GraphCellularAutomata(nn.Module):
-
     def __init__(self, input_dim, hidden_dim, output_dim):
         super(GraphCellularAutomata, self).__init__()

@@ -18,7 +17,6 @@ def forward(self, x):


 class ReplicationModel(nn.Module):
-
     def __init__(self, input_dim, hidden_dim):
         super(ReplicationModel, self).__init__()

@@ -34,27 +32,26 @@ def forward(self, x):


 class WeightUpdateModel(nn.Module):
-
     def __init__(self, input_dim, hidden_dim):
         super(WeightUpdateModel, self).__init__()

-        self.mlp = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(),
-                                 nn.Linear(hidden_dim, 1))
+        self.mlp = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
+        )

     def forward(self, x):
         return self.mlp(x)


 class NDP(nn.Module):
-
     def __init__(self, embedding_dim, hidden_dim):
         super(NDP, self).__init__()

-        self.gc_automata = GraphCellularAutomata(embedding_dim, hidden_dim,
-                                                 embedding_dim)
+        self.gc_automata = GraphCellularAutomata(
+            embedding_dim, hidden_dim, embedding_dim
+        )
         self.replication_model = ReplicationModel(embedding_dim, hidden_dim)
-        self.weight_update_model = WeightUpdateModel(2 * embedding_dim,
-                                                     hidden_dim)
+        self.weight_update_model = WeightUpdateModel(2 * embedding_dim, hidden_dim)

     def forward(self, node_embeddings, adjacency_matrix):
         # Update node embeddings using Graph Cellular Automata
Expand All @@ -70,10 +67,10 @@ def forward(self, node_embeddings, adjacency_matrix):
         for i in range(num_nodes):
             for j in range(num_nodes):
                 combined_embedding = torch.cat(
-                    (updated_embeddings[i], updated_embeddings[j]))
+                    (updated_embeddings[i], updated_embeddings[j])
+                )

-                edge_weights[i,
-                             j] = self.weight_update_model(combined_embedding)
+                edge_weights[i, j] = self.weight_update_model(combined_embedding)

         return updated_embeddings, replication_decisions, edge_weights

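Finally, a hedged usage sketch for the NDP module: per the forward pass above, it returns updated embeddings, replication decisions, and a dense edge-weight matrix. The import path mirrors this file's name, and all dimensions are illustrative:

    import torch

    from swarms_torch.graph_cellular_automa import NDP

    model = NDP(embedding_dim=16, hidden_dim=32)
    node_embeddings = torch.randn(4, 16)                 # 4 nodes, 16-dim embeddings
    adjacency = torch.randint(0, 2, (4, 4)).float()      # toy adjacency matrix

    embeddings, replication, edge_weights = model(node_embeddings, adjacency)
    print(embeddings.shape, replication.shape, edge_weights.shape)
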
(The remaining 23 of the 30 changed files were not rendered on this page.)
