Skip to content

Commit

Permalink
[BUGF][MixtureOfMambas] [Fusion Methods] [average, weighted, absmax, …
Browse files Browse the repository at this point in the history
…weighted_softmax, or your own custom function] [CODE QUALITY][+++]
  • Loading branch information
Kye committed Jan 7, 2024
1 parent 610feaf commit 9ea2d7b
Show file tree
Hide file tree
Showing 29 changed files with 387 additions and 123 deletions.
34 changes: 18 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,27 @@ visualize_swarmalators(xi)
```

### Mixture of Mambas
- Mixture of Mamba models of SSMs, could be improved with a router of some kind or better aggregation methods!
- A 100% novel implementation of a swarm of MixtureOfMambas.
- Various fusion methods such as averaging and weighted aggregation, with more to come, like a gating mechanism.
- fusion methods: average, weighted, absmax, weighted_softmax, or your own custom function

```python
import torch
from swarms_torch.mixture_of_mamba import MixtureOfMambas

# Example Usage
num_models = 3
dim = 16
state_range = (1, 20)
conv_range = (1, 10)
expand_range = (1, 5)

mixture_model = MixtureOfMambas(num_models, dim, state_range, conv_range, expand_range)
x = torch.randn(2, 64, dim).to("cuda")
output = mixture_model(
x, aggregation_method="average"
) # Or use 'weighted' with weights

from swarms_torch import MixtureOfMambas

# 3D Tensor for text
x = torch.rand(1, 512, 512)
model = MixtureOfMambas(
num_mambas=2,
dim=512,
d_state=1024,
depth=4,
d_conv=1024,
expand=4,
fusion_method="average",
custom_fusion_func=None,
)
print(model(x).shape)

```

Expand Down
4 changes: 3 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from swarms_torch import ParticleSwarmOptimization

# test
pso = ParticleSwarmOptimization(goal="Attention is all you need", n_particles=100)
pso = ParticleSwarmOptimization(
goal="Attention is all you need", n_particles=100
)
pso.optimize(iterations=1000)
6 changes: 4 additions & 2 deletions playground/mixture_of_mambas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
conv_range = (1, 10)
expand_range = (1, 5)

mixture_model = MixtureOfMambas(num_models, dim, state_range, conv_range, expand_range)
mixture_model = MixtureOfMambas(
num_models, dim, state_range, conv_range, expand_range
)
x = torch.randn(2, 64, dim).to("cuda")
output = mixture_model(
x, aggregation_method="average"
x, fusion_method="average"
) # Or use 'weighted' with weights
4 changes: 3 additions & 1 deletion playground/silu_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def silu(x):
ax = fig.add_subplot(111, projection="3d")

# Scatter plot of SiLU outputs
ax.scatter(input_values, output_values, input_values, c=output_values, cmap="viridis")
ax.scatter(
input_values, output_values, input_values, c=output_values, cmap="viridis"
)
ax.set_xlabel("Input")
ax.set_ylabel("Output")
ax.set_zlabel("Input")
Expand Down
27 changes: 24 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,32 @@ python = "^3.6"
torch = "*"
einops = "*"
zetascale = "*"
mamba-ssm = "*"
causal-conv1d = "1.1.0"




[tool.poetry.dev-dependencies]
# Add development dependencies here
[tool.poetry.group.lint.dependencies]
ruff = ">=0.0.249,<0.1.10"
types-toml = "^0.10.8.1"
types-redis = "^4.3.21.6"
types-pytz = "^2023.3.0.0"
black = "^23.1.0"
types-chardet = "^5.0.4.6"
mypy-protobuf = "^3.0.0"


[tool.autopep8]
max_line_length = 80
ignore = "E501,W6" # or ["E501", "W6"]
in-place = true
recursive = true
aggressive = 3

[tool.ruff]
line-length = 80

[tool.black]
line-length = 80
target-version = ['py38']
preview = true
6 changes: 4 additions & 2 deletions scripts/requirementstxt_to_pyproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ def update_pyproject_versions(pyproject_path):
print(f"Error: The file '{pyproject_path}' was not found.")
return
except toml.TomlDecodeError:
print(f"Error: The file '{pyproject_path}' is not a valid TOML" " file.")
print(f"Error: The file '{pyproject_path}' is not a valid TOML file.")
return

dependencies = data.get("tool", {}).get("poetry", {}).get("dependencies", {})
dependencies = (
data.get("tool", {}).get("poetry", {}).get("dependencies", {})
)

for package in dependencies:
if package.lower() == "python":
Expand Down
5 changes: 4 additions & 1 deletion swarms_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from swarms_torch.particle_swarm import ParticleSwarmOptimization
from swarms_torch.queen_bee import QueenBeeGa
from swarms_torch.spiral_optimization import SPO
from swarms_torch.transformer_pso import Particle, TransformerParticleSwarmOptimization
from swarms_torch.transformer_pso import (
Particle,
TransformerParticleSwarmOptimization,
)

__all__ = [
"ParticleSwarmOptimization",
Expand Down
49 changes: 35 additions & 14 deletions swarms_torch/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def align_right(t, lens, pad_id=0):
pad_lens = seq_len - lens
max_pad_len = pad_lens.amax()

batch_arange = torch.arange(batch, device=device, dtype=torch.long)[..., None]
batch_arange = torch.arange(batch, device=device, dtype=torch.long)[
..., None
]
prompt_len_arange = torch.arange(seq_len, device=device, dtype=torch.long)

t = F.pad(t, (max_pad_len, 0), value=0)
Expand All @@ -65,7 +67,9 @@ def top_p(logits, thres=0.9):
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value=False)
sorted_indices_to_remove = F.pad(
sorted_indices_to_remove, (1, -1), value=False
)

sorted_logits[sorted_indices_to_remove] = float("-inf")
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
Expand Down Expand Up @@ -118,7 +122,12 @@ def contrastive_decode_fn(expert_logits, amateur_logits, alpha=0.1, beta=0.5):

class AutoregressiveWrapper(Module):
def __init__(
self, net, ignore_index=-100, pad_value=0, mask_prob=0.0, add_attn_z_loss=False
self,
net,
ignore_index=-100,
pad_value=0,
mask_prob=0.0,
add_attn_z_loss=False,
):
super().__init__()
self.pad_value = pad_value
Expand Down Expand Up @@ -149,9 +158,11 @@ def generate(
restrict_to_max_seq_len=True,
amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
filter_kwargs: dict = dict(),
contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(beta=0.5, alpha=0.1),
contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
beta=0.5, alpha=0.1
),
cache_kv=True,
**kwargs
**kwargs,
):
max_seq_len, device = self.max_seq_len, prompts.device

Expand Down Expand Up @@ -200,15 +211,16 @@ def generate(
if exists(cache):
for inter in cache.attn_intermediates:
inter.cached_kv = [
t[..., -(max_seq_len - 1) :, :] for t in inter.cached_kv
t[..., -(max_seq_len - 1) :, :]
for t in inter.cached_kv
]

logits, new_cache = self.net(
x,
return_intermediates=True,
cache=cache,
seq_start_pos=seq_start_pos,
**kwargs
**kwargs,
)

if cache_kv and self.net.can_cache_kv:
Expand All @@ -225,24 +237,28 @@ def generate(
amateur_cache,
amateur_contrastive_decode_kwargs,
) in enumerate(
zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
zip(
amateur_model, amateur_caches, contrastive_decode_kwargs
)
):
amateur_logits, next_amateur_cache = amateur(
x,
return_intermediates=True,
cache=amateur_cache,
seq_start_pos=seq_start_pos,
**kwargs
**kwargs,
)

amateur_logits = amateur_logits[:, -1]

assert amateur_logits.shape == logits.shape, (
"logits dimension are not the same between amateur and expert"
" model"
"logits dimension are not the same between amateur and"
" expert model"
)
logits = contrastive_decode_fn(
logits, amateur_logits, **amateur_contrastive_decode_kwargs
logits,
amateur_logits,
**amateur_contrastive_decode_kwargs,
)

if cache_kv and amateur.can_cache_kv:
Expand Down Expand Up @@ -295,11 +311,16 @@ def forward(self, x, **kwargs):
kwargs.update(self_attn_kv_mask=mask)

logits, cache = self.net(
inp, return_intermediates=True, return_attn_z_loss=add_attn_z_loss, **kwargs
inp,
return_intermediates=True,
return_attn_z_loss=add_attn_z_loss,
**kwargs,
)

loss = F.cross_entropy(
rearrange(logits, "b n c -> b c n"), target, ignore_index=ignore_index
rearrange(logits, "b n c -> b c n"),
target,
ignore_index=ignore_index,
)

if add_attn_z_loss:
Expand Down
17 changes: 13 additions & 4 deletions swarms_torch/fish_school.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def __init__(
):
super().__init__()
self.model = Transformer(
d_model=dim, nhead=heads, num_encoder_layers=depth, num_decoder_layers=depth
d_model=dim,
nhead=heads,
num_encoder_layers=depth,
num_decoder_layers=depth,
)
self.optimizer = Adam(self.parameters())
self.scheduler = ReduceLROnPlateau(self.optimizer, "min")
Expand All @@ -96,13 +99,17 @@ def train(self, src, tgt, labels):
outputs = self.model(src, tgt)

# cross entropy loss
loss = CrossEntropyLoss()(outputs.view(-1, outputs.size(-1)), labels.view(-1))
loss = CrossEntropyLoss()(
outputs.view(-1, outputs.size(-1)), labels.view(-1)
)

# complexity regularization by adding the sum of the squares of the
# weights
if self.complexity_regularization:
# complexity regularization
loss += self.alpha * sum(p.pow(2.0).sum() for p in self.model.parameters())
loss += self.alpha * sum(
p.pow(2.0).sum() for p in self.model.parameters()
)

# backpropagation
loss.backward()
Expand Down Expand Up @@ -211,7 +218,9 @@ def forward(self, src, tgt, labels):
# with higher food
if self.complex_school:
for fish in self.fish:
neighbor = self.fish[torch.randint(0, len(self.fish), (1,)).item()]
neighbor = self.fish[
torch.randint(0, len(self.fish), (1,)).item()
]
if neighbor.food > fish.food:
fish.model.load_state_dict(neighbor.model.state_dict())

Expand Down
12 changes: 9 additions & 3 deletions swarms_torch/graph_cellular_automa.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __init__(self, input_dim, hidden_dim):
super(WeightUpdateModel, self).__init__()

self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)

def forward(self, x):
Expand All @@ -51,7 +53,9 @@ def __init__(self, embedding_dim, hidden_dim):
embedding_dim, hidden_dim, embedding_dim
)
self.replication_model = ReplicationModel(embedding_dim, hidden_dim)
self.weight_update_model = WeightUpdateModel(2 * embedding_dim, hidden_dim)
self.weight_update_model = WeightUpdateModel(
2 * embedding_dim, hidden_dim
)

def forward(self, node_embeddings, adjacency_matrix):
# Update node embeddings using Graph Cellular Automata
Expand All @@ -70,7 +74,9 @@ def forward(self, node_embeddings, adjacency_matrix):
(updated_embeddings[i], updated_embeddings[j])
)

edge_weights[i, j] = self.weight_update_model(combined_embedding)
edge_weights[i, j] = self.weight_update_model(
combined_embedding
)

return updated_embeddings, replication_decisions, edge_weights

Expand Down
4 changes: 3 additions & 1 deletion swarms_torch/hivemind_swarm_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __init__(
dim_head=dim_head,
)
# Create a list of transformers sharing the same weights
self.experts = nn.ModuleList([self.base_transformer for _ in range(num_models)])
self.experts = nn.ModuleList(
[self.base_transformer for _ in range(num_models)]
)

# The gating mechanism allows the model to dynamically weight the contribution of each transformer
# in the swarm. This is done by learning a weight for each transformer and then using a softmax
Expand Down
7 changes: 5 additions & 2 deletions swarms_torch/ma_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ def __init__(self, env_name, num_agents):
self.envs = [gym.make(env_name) for _ in range(num_agents)]
self.agents = [
MAgent.Agent(
self.envs[0].observation_space.shape[0], self.envs[0].action_space.n
self.envs[0].observation_space.shape[0],
self.envs[0].action_space.n,
)
for _ in range(num_agents)
]
self.optimizers = [optim.Adam(agent.parameters()) for agent in self.agents]
self.optimizers = [
optim.Adam(agent.parameters()) for agent in self.agents
]

def step(self, agent_actions):
rewards = []
Expand Down
Loading

0 comments on commit 9ea2d7b

Please sign in to comment.