Jcaip/llm bsr #1601

Draft: wants to merge 18 commits into main
26 changes: 21 additions & 5 deletions torchao/_models/llama/generate.py
@@ -23,7 +23,6 @@
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
-torch.backends.cuda.enable_cudnn_sdp(True)


class HostEvent:
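For context on the two global toggles above: `_FORCE_CUTLASS = False` lets 2:4 semi-structured sparse ops dispatch to cuSPARSELt rather than the CUTLASS kernels, and `enable_cudnn_sdp` opts scaled-dot-product attention in or out of the cuDNN backend (the hunk header indicates the PR deletes one of these two lines). A minimal sketch of how these knobs are queried and set:

import torch

# Prefer cuSPARSELt kernels over CUTLASS for 2:4 semi-structured sparse matmuls
torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

# The cuDNN SDPA backend can be queried and toggled explicitly
print(torch.backends.cuda.cudnn_sdp_enabled())  # default varies by PyTorch version
torch.backends.cuda.enable_cudnn_sdp(False)     # leave SDPA to the other backends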
@@ -794,9 +793,26 @@ def ffn_or_attn_only(mod, fqn):
from torchao.sparsity import semi_sparse_weight, sparsify_

if "semi" in sparsity:
# TODO there is a bug here, need to fix
# Fixed sparsity level for 2:4
sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)

if "bsr" in sparsity:
# Apply Supermask to get sparse weights
from torchao.prototype.sparsity.superblock.supermask import SupermaskLinear
sparsify_(
model,
lambda x: SupermaskLinear.from_linear(x,
sparsity_level=0.9,
blocksize=64,
),
filter_fn=ffn_only,
)

from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight
sparsify_(model,
block_sparse_weight(blocksize=64),
filter_fn=ffn_only)

model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

if save:
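Below is a minimal, self-contained sketch of the two-step BSR flow this hunk adds, using the same prototype APIs the diff imports. `ToyFFN` and the `ffn_only` filter here are illustrative stand-ins for the Llama FFN modules and the PR's filter function, and the exact prototype semantics may differ:

import torch
import torch.nn as nn

from torchao.sparsity import sparsify_
from torchao.prototype.sparsity.superblock.supermask import SupermaskLinear
from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight

class ToyFFN(nn.Module):
    # illustrative stand-in for a transformer feed-forward block
    def __init__(self, dim=256):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

def ffn_only(mod, fqn):
    # stand-in for the PR's filter: only sparsify the FFN linear layers
    return isinstance(mod, nn.Linear)

model = ToyFFN().cuda().half()  # block-sparse kernels target fp16/bf16 on GPU

# Step 1: replace each matching Linear with a SupermaskLinear that prunes
# 90% of the weights in 64x64 blocks (mirrors the diff's arguments).
sparsify_(
    model,
    lambda m: SupermaskLinear.from_linear(m, sparsity_level=0.9, blocksize=64),
    filter_fn=ffn_only,
)

# Step 2: convert the masked dense weights to block-sparse (BSR) storage so
# matmuls dispatch to block-sparse kernels.
sparsify_(model, block_sparse_weight(blocksize=64), filter_fn=ffn_only)

out = model(torch.randn(8, 256, device="cuda", dtype=torch.half))

Note the contrast with the 2:4 path: `semi_sparse_weight()` takes no sparsity_level argument because the 2:4 pattern fixes sparsity at 50% (two nonzeros in every group of four elements), whereas the Supermask path picks both a level (90%) and a block size (64) that must match the later `block_sparse_weight(blocksize=64)` conversion.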
@@ -811,7 +827,7 @@ def ffn_or_attn_only(mod, fqn):
print("Compiling Model")
global decode_one_token, prefill
decode_one_token = torch.compile(
-   decode_one_token, mode="reduce-overhead", fullgraph=True
+   decode_one_token, mode="reduce-overhead", fullgraph=True, dynamic=True,
)

if compile_prefill:
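The added `dynamic=True` plausibly pairs with the new `prefill_size` guards below (my reading, not stated in the PR): it asks torch.compile to trace with a symbolic sequence dimension, so calls with varying lengths reuse one compiled graph instead of triggering a recompile per shape. A minimal sketch with a toy function:

import torch

def step(x):
    # stand-in for a single decode step
    return x * 2 + 1

step = torch.compile(step, mode="reduce-overhead", fullgraph=True, dynamic=True)

# Compiled once with a symbolic size; these calls share one graph.
for n in (128, 256, 512):
    step(torch.randn(1, n))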
@@ -850,7 +866,7 @@ def ffn_or_attn_only(mod, fqn):
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

-if interactive and i >= 0:
+if interactive and i >= 0 and prefill_size is None:
buffer = []
period_id = tokenizer.encode(".")[0]
done_generating = False
@@ -920,7 +936,7 @@ def callback(x):
device_sync(device=device) # MKG
t = time.perf_counter() - t0

-if not interactive and demo_summarize_prompt is None:
+if not interactive and demo_summarize_prompt is None and prefill_size is None:
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = (