fix imports sorting
Vectorrent committed Nov 16, 2024
1 parent b596c4c commit 06b9db2
Showing 15 changed files with 99 additions and 43 deletions.
7 changes: 2 additions & 5 deletions eval.py
@@ -1,13 +1,10 @@
import os

import lm_eval
from lm_eval.models.huggingface import HFLM
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from praxis import (
PraxisConfig,
PraxisForCausalLM,
PraxisModel,
)

from praxis import PraxisConfig, PraxisForCausalLM, PraxisModel

AutoConfig.register("praxis", PraxisConfig)
AutoModel.register(PraxisConfig, PraxisModel)
4 changes: 2 additions & 2 deletions praxis/__init__.py
@@ -1,4 +1,4 @@
from praxis.blocks import BLOCK_REGISTRY
from praxis.configuration_praxis import PraxisConfig
from praxis.modeling_praxis import PraxisForCausalLM, PraxisModel
from praxis.blocks import BLOCK_REGISTRY
from praxis.modules import EXPERT_REGISTRY, ENCODING_REGISTRY, EMBEDDING_REGISTRY
from praxis.modules import EMBEDDING_REGISTRY, ENCODING_REGISTRY, EXPERT_REGISTRY
4 changes: 2 additions & 2 deletions praxis/blocks/__init__.py
@@ -1,5 +1,5 @@
from praxis.blocks.transformer import PraxisBlock
from praxis.blocks.nano import PraxisNano
from praxis.blocks.conv import PraxisConv
from praxis.blocks.nano import PraxisNano
from praxis.blocks.transformer import PraxisBlock

BLOCK_REGISTRY = {"transformer": PraxisBlock, "nano": PraxisNano, "conv": PraxisConv}
86 changes: 70 additions & 16 deletions praxis/blocks/conv.py
@@ -1,12 +1,14 @@
import math
import time
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import time
import math
from typing import Optional
from praxis.modules.dense import PraxisGLU

from praxis.activations import ACT2FN
from praxis.modules.dense import PraxisGLU


class PraxisConv(nn.Module):
@@ -139,41 +141,93 @@ def forward(self, x):


class CausalGlobalContext(nn.Module):
"""
Implements a kind of squeeze-and-excitation mechanism, which allows
us to bridge convolutional operations' local contexts into a global one.
https://arxiv.org/abs/1904.11492v1
"""

def __init__(self, in_channels, reduction=0.125):
super().__init__()
bottleneck = int(in_channels * reduction)

# Context modeling - single 1x1 conv to generate global attention weights
self.context = nn.Conv1d(in_channels, 1, kernel_size=1)

# Bottleneck transform
# Bottleneck transform with Conv1d layers
self.transform = nn.Sequential(
nn.Linear(in_channels, bottleneck),
nn.LayerNorm(bottleneck),
# First conv reduces channels
nn.Conv1d(in_channels, bottleneck, kernel_size=1),
# LayerNorm needs to be applied to channel dim for conv
nn.GroupNorm(1, bottleneck), # equivalent to LayerNorm for conv
ACT2FN["periodic_relu"],
nn.Linear(bottleneck, in_channels),
# Second conv restores channels
nn.Conv1d(bottleneck, in_channels, kernel_size=1),
)

def forward(self, x):
# Generate attention weights
attn = self.context(x) # B, 1, T

# Apply causal masking by setting future weights to -inf
# Apply causal masking
mask = torch.triu(torch.ones_like(attn), diagonal=1)
attn = attn.masked_fill(mask.bool(), float("-inf"))

# Softmax to get attention distribution
attn = F.softmax(attn, dim=-1) # B, 1, T

# Calculate global context
context = torch.matmul(x, attn.transpose(-2, -1)) # B, C, 1
context = context.squeeze(-1) # B, C

# Transform through bottleneck
context = self.transform(context) # B, C
# Transform through bottleneck (no need to squeeze/unsqueeze)
context = self.transform(context) # B, C, 1

# Broadcast and add to input
return x + context.expand(-1, -1, x.size(2))


# class CausalGlobalContext(nn.Module):
# """
# Implements a kind of squeeze-and-excitation mechanism, which allows
# us to bridge convolutional operations' local contexts, into a global one.
# https://arxiv.org/abs/1904.11492v1
# """

# def __init__(self, in_channels, reduction=0.125):
# super().__init__()
# bottleneck = int(in_channels * reduction)

# # Context modeling - single 1x1 conv to generate global attention weights
# self.context = nn.Conv1d(in_channels, 1, kernel_size=1)

# # Bottleneck transform
# self.transform = nn.Sequential(
# nn.Linear(in_channels, bottleneck),
# # nn.Conv1d(in_channels, 1, kernel_size=1),
# nn.LayerNorm(bottleneck),
# ACT2FN["periodic_relu"],
# # nn.Conv1d(1, in_channels, kernel_size=1),
# nn.Linear(bottleneck, in_channels),
# )

# def forward(self, x):
# # Generate attention weights
# attn = self.context(x) # B, 1, T

# # Apply causal masking by setting future weights to -inf
# mask = torch.triu(torch.ones_like(attn), diagonal=1)
# attn = attn.masked_fill(mask.bool(), float("-inf"))

# # Softmax to get attention distribution
# attn = F.softmax(attn, dim=-1) # B, 1, T

# # Calculate global context
# context = torch.matmul(x, attn.transpose(-2, -1)) # B, C, 1
# context = context.squeeze(-1) # B, C

# # Transform through bottleneck
# context = self.transform(context) # B, C

# Add to all positions
return x + context.unsqueeze(-1) # B, C, T
# # Add to all positions
# return x + context.unsqueeze(-1) # B, C, T


# class CausalGlobalContext(nn.Module):
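Aside from the import reordering, the substantive change in conv.py above is that CausalGlobalContext's bottleneck now uses 1x1 Conv1d layers with GroupNorm in place of Linear layers with LayerNorm, so the pooled context stays a (batch, channels, 1) tensor and no longer needs a squeeze/unsqueeze round trip. Below is a minimal, self-contained sketch of the new block; it is not taken from the repository, the class name is hypothetical, and nn.GELU stands in for ACT2FN["periodic_relu"] because praxis.activations is not part of this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalGlobalContextSketch(nn.Module):
    """Sketch of the Conv1d-based causal global-context block shown above."""

    def __init__(self, in_channels, reduction=0.125):
        super().__init__()
        bottleneck = int(in_channels * reduction)
        # 1x1 conv produces one attention logit per time step
        self.context = nn.Conv1d(in_channels, 1, kernel_size=1)
        # Bottleneck transform applied to the (B, C, 1) context tensor
        self.transform = nn.Sequential(
            nn.Conv1d(in_channels, bottleneck, kernel_size=1),
            nn.GroupNorm(1, bottleneck),  # single group: LayerNorm-like over channels
            nn.GELU(),  # assumption: stand-in for ACT2FN["periodic_relu"]
            nn.Conv1d(bottleneck, in_channels, kernel_size=1),
        )

    def forward(self, x):  # x: (B, C, T)
        attn = self.context(x)  # (B, 1, T)
        # Causal masking, mirroring the diff above
        mask = torch.triu(torch.ones_like(attn), diagonal=1)
        attn = attn.masked_fill(mask.bool(), float("-inf"))
        attn = F.softmax(attn, dim=-1)
        context = torch.matmul(x, attn.transpose(-2, -1))  # (B, C, 1)
        context = self.transform(context)  # (B, C, 1)
        return x + context.expand(-1, -1, x.size(2))  # broadcast back to (B, C, T)


if __name__ == "__main__":
    x = torch.randn(2, 64, 16)  # (batch, channels, time)
    block = CausalGlobalContextSketch(in_channels=64)
    print(block(x).shape)  # torch.Size([2, 64, 16])

    # Sanity check for the "equivalent to LayerNorm for conv" comment: on a
    # (B, C, 1) tensor, GroupNorm with one group normalizes over C only,
    # which matches LayerNorm applied over the channel dimension.
    c = torch.randn(4, 8, 1)
    gn = nn.GroupNorm(1, 8, affine=False)
    ln = nn.LayerNorm(8, elementwise_affine=False)
    print(torch.allclose(gn(c), ln(c.squeeze(-1)).unsqueeze(-1), atol=1e-6))  # True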
11 changes: 7 additions & 4 deletions praxis/blocks/nano.py
@@ -1,9 +1,11 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math
from typing import Optional

from praxis.modules.dense import PraxisGLU


@@ -106,10 +108,11 @@ def _interpolate_weights(self, features):


if __name__ == "__main__":
import numpy as np
from dataclasses import dataclass
import random
import time
from dataclasses import dataclass

import numpy as np

# Mock AutoConfig class to simulate the configuration
@dataclass
1 change: 0 additions & 1 deletion praxis/blocks/transformer.py
@@ -10,7 +10,6 @@
from praxis.modules.attention import PraxisAttention
from praxis.modules.experts import EXPERT_REGISTRY


input_shape = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))


2 changes: 1 addition & 1 deletion praxis/modeling_praxis.py
@@ -10,7 +10,7 @@
)

from praxis import PraxisConfig
from praxis.modules import MultiIdentity, EMBEDDING_REGISTRY
from praxis.modules import EMBEDDING_REGISTRY, MultiIdentity
from praxis.modules.compression import PraxisCompressor
from praxis.modules.decoder import PraxisDecoder

4 changes: 2 additions & 2 deletions praxis/modules/__init__.py
@@ -1,4 +1,4 @@
from praxis.modules.common import MultiIdentity, ENCODING_REGISTRY
from praxis.blocks import BLOCK_REGISTRY
from praxis.modules.experts import EXPERT_CONFIGS, EXPERT_REGISTRY
from praxis.modules.common import ENCODING_REGISTRY, MultiIdentity
from praxis.modules.embeddings import EMBEDDING_REGISTRY
from praxis.modules.experts import EXPERT_CONFIGS, EXPERT_REGISTRY
2 changes: 1 addition & 1 deletion praxis/modules/attention.py
@@ -7,8 +7,8 @@
from torch import Tensor
from transformers import AutoConfig

from praxis.modules.memory import PraxisMemory
from praxis.modules.common import ENCODING_REGISTRY
from praxis.modules.memory import PraxisMemory


class PraxisAttention(nn.Module):
1 change: 1 addition & 0 deletions praxis/modules/common.py
@@ -1,4 +1,5 @@
import torch.nn as nn

from praxis.modules.encoding import ALiBi, NoPE, RoPE, YaRN

ENCODING_REGISTRY = {"alibi": ALiBi, "nope": NoPE, "rope": RoPE, "yarn": YaRN}
6 changes: 4 additions & 2 deletions praxis/modules/dense.py
@@ -1,7 +1,9 @@
from typing import OrderedDict, Optional
from typing import Optional, OrderedDict

from torch import nn
from transformers import AutoConfig
from praxis.activations import ACT2FN, ACT2CLS

from praxis.activations import ACT2CLS, ACT2FN


class PraxisMLP(nn.Sequential):
6 changes: 3 additions & 3 deletions praxis/modules/encoding.py
@@ -1,8 +1,8 @@
import torch
from torch import nn
import torch.nn.functional as F
import math

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig


2 changes: 1 addition & 1 deletion praxis/modules/experts.py
@@ -7,8 +7,8 @@

from praxis.modules.dense import PraxisGLU, PraxisMLP
from praxis.modules.peer import PraxisPEER
from praxis.modules.smear import PraxisSMEAR
from praxis.modules.router import PraxisMixtureOfDepths
from praxis.modules.smear import PraxisSMEAR

EXPERT_REGISTRY = {
"mlp": PraxisMLP,
4 changes: 2 additions & 2 deletions run.py
@@ -82,6 +82,7 @@ def check_for_updates():
import time
import traceback
import uuid
import warnings
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timedelta
@@ -110,7 +111,6 @@ def check_for_updates():
AutoTokenizer,
PreTrainedTokenizer,
)
import warnings

warnings.filterwarnings("ignore", ".*Checkpoint directory.*exists and is not empty*")

@@ -119,8 +119,8 @@ def check_for_updates():
from interface import TerminalDashboard
from praxis import (
BLOCK_REGISTRY,
EXPERT_REGISTRY,
ENCODING_REGISTRY,
EXPERT_REGISTRY,
PraxisConfig,
PraxisForCausalLM,
PraxisModel,
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@
]

# Requirements primarily for IDE-related tooling
dev_requirements = ["isort", "lsprotocol", "pygls"]
dev_requirements = ["attrs", "isort", "lsprotocol", "pygls"]

setup(
name="praxis",
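Since the commit is titled "fix imports sorting" and the dev requirements above now include isort, the reordering throughout these files is consistent with isort's default grouping: standard library imports first, then third-party, then first-party, each group alphabetized and separated by a blank line. For reference, a small sketch of the programmatic API, assuming default settings and no repo-specific isort configuration (none is visible in this diff):

import isort

# The "before" state of praxis/modules/encoding.py, as shown above.
before = (
    "import torch\n"
    "from torch import nn\n"
    "import torch.nn.functional as F\n"
    "import math\n"
)
print(isort.code(before))
# Expected with default settings:
#   import math
#
#   import torch
#   import torch.nn.functional as F
#   from torch import nn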
