Skip to content

Commit

Permalink
multiswarm pso prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Oct 12, 2023
1 parent faccaa2 commit b00b19c
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 66 deletions.
29 changes: 11 additions & 18 deletions swarms_torch/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def align_right(t, lens, pad_id=0):
pad_lens = seq_len - lens
max_pad_len = pad_lens.amax()

batch_arange = torch.arange(
batch, device=device, dtype=torch.long)[..., None]
batch_arange = torch.arange(batch, device=device, dtype=torch.long)[..., None]
prompt_len_arange = torch.arange(seq_len, device=device, dtype=torch.long)

t = F.pad(t, (max_pad_len, 0), value=0)
Expand All @@ -66,8 +65,7 @@ def top_p(logits, thres=0.9):
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove = F.pad(
sorted_indices_to_remove, (1, -1), value=False)
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value=False)

sorted_logits[sorted_indices_to_remove] = float("-inf")
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
Expand Down Expand Up @@ -120,12 +118,8 @@ def contrastive_decode_fn(expert_logits, amateur_logits, alpha=0.1, beta=0.5):

class AutoregressiveWrapper(Module):
def __init__(
self,
net,
ignore_index=-100,
pad_value=0,
mask_prob=0.0,
add_attn_z_loss=False):
self, net, ignore_index=-100, pad_value=0, mask_prob=0.0, add_attn_z_loss=False
):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
Expand Down Expand Up @@ -206,7 +200,7 @@ def generate(
if exists(cache):
for inter in cache.attn_intermediates:
inter.cached_kv = [
t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv
t[..., -(max_seq_len - 1) :, :] for t in inter.cached_kv
]

logits, new_cache = self.net(
Expand Down Expand Up @@ -247,7 +241,8 @@ def generate(
amateur_logits.shape == logits.shape
), "logits dimension are not the same between amateur and expert model"

Check failure on line 242 in swarms_torch/autoregressive.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

swarms_torch/autoregressive.py:242:89: E501 Line too long (91 > 88 characters)
logits = contrastive_decode_fn(
logits, amateur_logits, **amateur_contrastive_decode_kwargs)
logits, amateur_logits, **amateur_contrastive_decode_kwargs
)

if cache_kv and amateur.can_cache_kv:
amateur_caches[i] = next_amateur_cache
Expand Down Expand Up @@ -299,14 +294,12 @@ def forward(self, x, **kwargs):
kwargs.update(self_attn_kv_mask=mask)

logits, cache = self.net(
inp, return_intermediates=True, return_attn_z_loss=add_attn_z_loss, **kwargs)
inp, return_intermediates=True, return_attn_z_loss=add_attn_z_loss, **kwargs
)

loss = F.cross_entropy(
rearrange(
logits,
"b n c -> b c n"),
target,
ignore_index=ignore_index)
rearrange(logits, "b n c -> b c n"), target, ignore_index=ignore_index
)

if add_attn_z_loss:
loss = loss + cache.attn_z_loss
Expand Down
12 changes: 4 additions & 8 deletions swarms_torch/fish_school.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ def __init__(
):
super().__init__()
self.model = Transformer(
d_model=dim,
nhead=heads,
num_encoder_layers=depth,
num_decoder_layers=depth)
d_model=dim, nhead=heads, num_encoder_layers=depth, num_decoder_layers=depth
)
self.optimizer = Adam(self.parameters())
self.scheduler = ReduceLROnPlateau(self.optimizer, "min")

Expand Down Expand Up @@ -104,8 +102,7 @@ def train(self, src, tgt, labels):
# weights
if self.complexity_regularization:
# complexity regularization
loss += self.alpha * sum(p.pow(2.0).sum()
for p in self.model.parameters())
loss += self.alpha * sum(p.pow(2.0).sum() for p in self.model.parameters())

# backpropagation
loss.backward()
Expand Down Expand Up @@ -214,8 +211,7 @@ def forward(self, src, tgt, labels):
# with higher food
if self.complex_school:
for fish in self.fish:
neighbor = self.fish[torch.randint(
0, len(self.fish), (1,)).item()]
neighbor = self.fish[torch.randint(0, len(self.fish), (1,)).item()]
if neighbor.food > fish.food:
fish.model.load_state_dict(neighbor.model.state_dict())

Expand Down
11 changes: 4 additions & 7 deletions swarms_torch/graph_cellular_automa.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ def __init__(self, input_dim, hidden_dim):
super(WeightUpdateModel, self).__init__()

self.mlp = nn.Sequential(
nn.Linear(
input_dim, hidden_dim), nn.ReLU(), nn.Linear(
hidden_dim, 1))
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
)

def forward(self, x):
return self.mlp(x)
Expand All @@ -52,8 +51,7 @@ def __init__(self, embedding_dim, hidden_dim):
embedding_dim, hidden_dim, embedding_dim
)
self.replication_model = ReplicationModel(embedding_dim, hidden_dim)
self.weight_update_model = WeightUpdateModel(
2 * embedding_dim, hidden_dim)
self.weight_update_model = WeightUpdateModel(2 * embedding_dim, hidden_dim)

def forward(self, node_embeddings, adjacency_matrix):
# Update node embeddings using Graph Cellular Automata
Expand All @@ -72,8 +70,7 @@ def forward(self, node_embeddings, adjacency_matrix):
(updated_embeddings[i], updated_embeddings[j])
)

edge_weights[i, j] = self.weight_update_model(
combined_embedding)
edge_weights[i, j] = self.weight_update_model(combined_embedding)

return updated_embeddings, replication_decisions, edge_weights

Expand Down
3 changes: 1 addition & 2 deletions swarms_torch/ma_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def __init__(self, env_name, num_agents):
)
for _ in range(num_agents)
]
self.optimizers = [optim.Adam(agent.parameters())
for agent in self.agents]
self.optimizers = [optim.Adam(agent.parameters()) for agent in self.agents]

def step(self, agent_actions):
rewards = []
Expand Down
152 changes: 152 additions & 0 deletions swarms_torch/multi_swarm_pso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import torch

Check failure on line 1 in swarms_torch/multi_swarm_pso.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

swarms_torch/multi_swarm_pso.py:1:8: F401 `torch` imported but unused
import random
import string


class MultiSwarmPSO:
    """
    Multi-swarm, PSO-inspired random search for a target string.

    Several independent sub-swarms of candidate strings ("particles") are
    mutated one character at a time. A mutation is kept only when it does not
    lower the particle's fitness, so every particle's fitness is
    non-decreasing over the iterations.

    Parameters
    ----------
    target_string : str
        The target string to be generated
    num_sub_swarms : int
        The number of sub-swarms
    num_particles_per_swarm : int
        The number of particles per sub-swarm
    max_iterations : int
        The maximum number of iterations to run the algorithm

    Attributes
    ----------
    target_string : str
        The target string to be generated
    num_sub_swarms : int
        The number of sub-swarms
    num_particles_per_swarm : int
        The number of particles per sub-swarm
    num_dimensions : int
        The number of dimensions in the search space (``len(target_string)``)
    max_iterations : int
        The maximum number of iterations to run the algorithm

    Methods
    -------
    generate_random_string()
        Generates a random string of length num_dimensions
    fitness_function(position)
        Calculates the fitness of a given position
    diversification_method(sub_swarms)
        Adds a new sub-swarm if the number of sub-swarms is less than the maximum
    optimize()
        Runs the search and returns the best position and its fitness

    Notes
    -----
    Candidates are drawn from ``ALPHABET`` (lowercase ASCII letters plus the
    space character); a target containing any other character can never be
    matched completely.

    References
    ----------
    .. [1] https://www.researchgate.net/publication/221172800_Multi-swarm_Particle_Swarm_Optimization

    Usage:
    ------
    target_string = "hello world"
    multi_swarm = MultiSwarmPSO(target_string)
    best_position, best_fitness = multi_swarm.optimize()
    """

    # Characters a particle may contain; shared by generation and mutation.
    ALPHABET = string.ascii_lowercase + " "

    def __init__(
        self,
        target_string,
        num_sub_swarms=5,
        num_particles_per_swarm=20,
        max_iterations=100,
    ):
        self.target_string = target_string
        self.num_sub_swarms = num_sub_swarms
        self.num_particles_per_swarm = num_particles_per_swarm
        self.num_dimensions = len(target_string)
        self.max_iterations = max_iterations

    def generate_random_string(self):
        """
        Generates a random string of length num_dimensions
        """
        return "".join(
            random.choice(self.ALPHABET) for _ in range(self.num_dimensions)
        )

    def fitness_function(self, position):
        """Fitness to be maximized: number of characters matching the target."""
        return sum(a == b for a, b in zip(position, self.target_string))

    def diversification_method(self, sub_swarms):
        """Append a fresh random sub-swarm while fewer than num_sub_swarms exist.

        ``sub_swarms`` is modified in place; nothing is returned.
        """
        if len(sub_swarms) < self.num_sub_swarms:
            sub_swarms.append(self._new_sub_swarm())

    def _new_sub_swarm(self):
        """Return a list of num_particles_per_swarm random particles."""
        return [
            self.generate_random_string()
            for _ in range(self.num_particles_per_swarm)
        ]

    def _mutate(self, particle):
        """Return a copy of ``particle`` with one random position re-drawn."""
        if not particle:
            # Empty target: there is no position to mutate.
            return particle
        index = random.randrange(len(particle))
        return particle[:index] + random.choice(self.ALPHABET) + particle[index + 1 :]

    def _global_best(self, sub_swarms):
        """Return ``(position, fitness)`` of the first fittest particle.

        Raises
        ------
        ValueError
            If there are no particles at all.
        """
        particles = [particle for sub_swarm in sub_swarms for particle in sub_swarm]
        if not particles:
            raise ValueError(
                "cannot optimize an empty swarm: num_sub_swarms and "
                "num_particles_per_swarm must both be positive"
            )
        # max() returns the first maximal element, matching a first-match scan.
        best_position = max(particles, key=self.fitness_function)
        return best_position, self.fitness_function(best_position)

    def optimize(self):
        """Run the search, printing progress each iteration.

        Returns
        -------
        tuple
            ``(global_best_position, global_best_fitness)`` after the last
            iteration.

        Raises
        ------
        ValueError
            If the swarm contains no particles.
        """
        sub_swarms = [self._new_sub_swarm() for _ in range(self.num_sub_swarms)]

        for iteration in range(self.max_iterations):
            for sub_swarm in sub_swarms:
                for index, particle in enumerate(sub_swarm):
                    candidate = self._mutate(particle)
                    # Write the accepted mutation back into the swarm.
                    # Rebinding the loop variable alone would discard it.
                    if self.fitness_function(candidate) >= self.fitness_function(
                        particle
                    ):
                        sub_swarm[index] = candidate

            self.diversification_method(sub_swarms)

            global_best_position, global_best_fitness = self._global_best(sub_swarms)
            print(
                f"Iteration {iteration}: "
                f"Global Best Fitness = {global_best_fitness}, "
                f"Global Best Position = {global_best_position}"
            )

        global_best_position, global_best_fitness = self._global_best(sub_swarms)
        print(
            f"Final Result: Global Best Fitness = {global_best_fitness}, "
            f"Global Best Position = {global_best_position}"
        )
        return global_best_position, global_best_fitness


# Example usage: search for "hello world" with the default swarm settings.
if __name__ == "__main__":
    target_string = "hello world"
    # The class defined in this module is MultiSwarmPSO; the previous name
    # `MultiSwarm` does not exist and raised NameError (ruff F821).
    multi_swarm = MultiSwarmPSO(target_string)
    multi_swarm.optimize()
3 changes: 1 addition & 2 deletions swarms_torch/neuronal_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ def __init__(self, neuron_count, num_states, input_dim, output_dim, nhead):
super(NNTransformer, self).__init__()

# Initialize neurons and synapses
self.neurons = nn.ModuleList(
[Neuron(num_states) for _ in range(neuron_count)])
self.neurons = nn.ModuleList([Neuron(num_states) for _ in range(neuron_count)])
self.synapses = nn.ModuleList(
[
SynapseTransformer(input_dim, output_dim, nhead)
Expand Down
3 changes: 1 addition & 2 deletions swarms_torch/particle_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,4 @@ def optimize(
for _ in range(iterations):
self.update()
best_particle = self.global_best
print("Best Particle: ", "".join(
[chr(int(i)) for i in best_particle]))
print("Best Particle: ", "".join([chr(int(i)) for i in best_particle]))
24 changes: 9 additions & 15 deletions swarms_torch/queen_bee.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def _evolve(self):
"""

# Sort population by fitness
fitnesses = 1.0 / \
torch.square(self.pool - self.target_gene).sum(dim=-1)
fitnesses = 1.0 / torch.square(self.pool - self.target_gene).sum(dim=-1)
indices = fitnesses.sort(descending=True).indices
self.pool, fitnesses = self.pool[indices], fitnesses[indices]

Expand All @@ -131,13 +130,11 @@ def _evolve(self):
self.queen_fitness, fitnesses = fitnesses[0], fitnesses[1:]

# Deterministic tournament selection
contender_ids = torch.randn((self.pop_size -
1, self.pop_size -
1)).argsort(dim=-
1)[..., : self.num_tournament_participants]
contender_ids = torch.randn((self.pop_size - 1, self.pop_size - 1)).argsort(
dim=-1
)[..., : self.num_tournament_participants]
participants, tournaments = self.pool[contender_ids], fitnesses[contender_ids]
top_winner = tournaments.topk(
1, dim=-1, largest=True, sorted=False).indices
top_winner = tournaments.topk(1, dim=-1, largest=True, sorted=False).indices
top_winner = top_winner.unsqueeze(-1).expand(-1, -1, self.gene_length)
parents = participants.gather(1, top_winner).squeeze(1)

Expand All @@ -146,7 +143,7 @@ def _evolve(self):
self.pop_size - 1, self.gene_length
)
self.pool = torch.cat(
(queen_parents[:, : self.gene_midpoint], parents[:, self.gene_midpoint:]),
(queen_parents[:, : self.gene_midpoint], parents[:, self.gene_midpoint :]),
dim=-1,
)

Expand All @@ -158,10 +155,8 @@ def _evolve(self):
mutated_pool = torch.where(mutate_mask, self.pool + noise, self.pool)

strong_mutate_mask = (
torch.randn(
self.pool.shape).argsort(
dim=-
1) < self.strong_num_code_mutate)
torch.randn(self.pool.shape).argsort(dim=-1) < self.strong_num_code_mutate
)
noise = torch.randint(0, 2, self.pool.shape) * 2 - 1
strong_mutated_pool = torch.where(
strong_mutate_mask, self.pool + noise, self.pool
Expand All @@ -180,8 +175,7 @@ def _check_convergence(self):
"""
Check if any of the solutions has achieved the goal
"""
fitnesses = 1.0 / \
torch.square(self.pool - self.target_gene).sum(dim=-1)
fitnesses = 1.0 / torch.square(self.pool - self.target_gene).sum(dim=-1)
return (fitnesses == float("inf")).any().item()


Expand Down
6 changes: 2 additions & 4 deletions swarms_torch/spiral_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@ def __init__(self, goal: str = None, m: int = 10, k_max: int = 1000):

# Initializing the search points and center randomly
# Note: 32-126 is the ASCII range for all printable characters
self.points = torch.randint(
32, 127, (self.m, self.n_dim), dtype=torch.float32)
self.center = torch.randint(
32, 127, (self.n_dim,), dtype=torch.float32)
self.points = torch.randint(32, 127, (self.m, self.n_dim), dtype=torch.float32)
self.center = torch.randint(32, 127, (self.n_dim,), dtype=torch.float32)

def _step_rate(self, k):
"""
Expand Down
Loading

0 comments on commit b00b19c

Please sign in to comment.