Llama-Vision: Enable tracing, refactor generation code #15005

Merged Nov 14, 2024 · 19 commits

Commits
e6651e9  #14519: Use FlashDecode in LlamaVision xattn (cglagovichTT, Nov 1, 2024)
7ed2318  #14519: WIP create simpler interface for LlamaVision (cglagovichTT, Nov 4, 2024)
42f87a7  #14519: Update xattn test input shapes since masks with non-causal Fl… (cglagovichTT, Nov 4, 2024)
97a9d45  #14519: Change TMs in xattn. Naive TMs now fail xattn test, so this c… (cglagovichTT, Nov 4, 2024)
236b5f8  #14519: Simple vision demo is functional, with llama_vision_model sup… (cglagovichTT, Nov 4, 2024)
555f27d  #14519: unit tests for xattn, xblock, and xtransformer now support ba… (cglagovichTT, Nov 5, 2024)
ef53146  #14519: Fix up Llama vision model. Simple demo works again with batch… (cglagovichTT, Nov 6, 2024)
7d50e6c  #14519: Refactor LlamaVision class to clean up separation of input pr… (cglagovichTT, Nov 6, 2024)
ff2f7ba  #14519: Don't pass full token tensor into decode and prefill (cglagovichTT, Nov 7, 2024)
51e2c89  #14519: Fix rebase issues (cglagovichTT, Nov 7, 2024)
0b7807f  #14519: Refactored decode input preparation to separate host tensor c… (cglagovichTT, Nov 7, 2024)
71f0727  #14519: Implement LlamaVision generation class which plugs into exist… (cglagovichTT, Nov 8, 2024)
81cc9b1  #14519: Fix test script now that pytest params changed (cglagovichTT, Nov 8, 2024)
1149598  #14519: Remove breakpoint (cglagovichTT, Nov 8, 2024)
9e5d0b9  #14519: license (cglagovichTT, Nov 8, 2024)
ac7ffcb  #14519: Remove trace decorator (cglagovichTT, Nov 13, 2024)
751e4b1  #14519: remove batch option from rot mat (cglagovichTT, Nov 13, 2024)
659c111  #14519: Add traced demo to CI (cglagovichTT, Nov 13, 2024)
31944db  #14519: Fix merge bug in xblock test (cglagovichTT, Nov 13, 2024)
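
Note on the tracing half of the PR title: ttnn supports trace capture and replay, which records a sequence of device operations once and then re-executes the recorded command stream without per-op host dispatch. Below is a minimal sketch of that pattern using ttnn's public trace API, not code from this PR; the device settings, the traced op (a single relu), and all variable names are placeholder assumptions.

import torch
import ttnn

# Minimal sketch of ttnn trace capture/replay, assuming a single device and a
# single traced op. The PR applies this idea to the Llama-Vision decode loop.
# trace_region_size reserves device memory for the trace buffer (size is a guess).
device = ttnn.open_device(device_id=0, trace_region_size=8192)

x = ttnn.from_torch(torch.rand(32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Warm-up run: compiles kernels and fills the program cache before capture.
y = ttnn.relu(x)

# Capture: ops enqueued between begin/end are recorded rather than re-dispatched.
trace_id = ttnn.begin_trace_capture(device, cq_id=0)
y = ttnn.relu(x)
ttnn.end_trace_capture(device, trace_id, cq_id=0)

# Replay: re-runs the recorded command stream against the same buffers,
# skipping host-side op dispatch; only the input contents may change.
ttnn.execute_trace(device, trace_id, cq_id=0, blocking=True)
result = ttnn.to_torch(y)

ttnn.release_trace(device, trace_id)
ttnn.close_device(device)
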
82 changes: 38 additions & 44 deletions models/demos/llama3/demo/multimodal_demo_chat.py
@@ -8,19 +8,21 @@
 from PIL import Image as PIL_Image
 from termcolor import cprint
 
-from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model
+import llama_models.llama3.reference_impl.generation as llama_reference_generation
+import pytest
+import os
+import ttnn
+
-import llama_models.llama3.reference_impl.generation as llama_reference_generation
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import ImageMedia, UserMessage
 
 from pkg_resources import resource_filename
 
 IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))
 
-import torch
-import pytest
-import os
-import ttnn
+from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
+from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
+
 
 @pytest.mark.parametrize(
@@ -36,39 +38,36 @@
     "target",
     ("tt", "cpu"),
 )
-@pytest.mark.parametrize(
-    "warmup_iters",
-    (0, 1),
-)
 def test_llama_multimodal_demo_chat(
     mesh_device,
     target,
-    warmup_iters,
     temperature: float = 0.5,
     top_p: float = 0.9,
     max_seq_len: int = 512,
-    max_batch_size: int = 4,
+    max_batch_size: int = 1,
     max_gen_len: Optional[int] = 200,
     model_parallel_size: Optional[int] = None,
 ):
-    mesh_device.enable_program_cache()
-    mesh_device.enable_async(True)
     ckpt_dir = os.environ["LLAMA_DIR"]
     tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
 
-    logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'")
-    generator = llama_reference_generation.Llama.build(
-        ckpt_dir,
-        tokenizer_path=tokenizer_path,
-        max_seq_len=max_seq_len,
-        max_batch_size=max_batch_size,
-        model_parallel_size=model_parallel_size,
-    )
 
-    if target == "tt":
+    if target == "cpu":
+        generator = llama_reference_generation.Llama.build(
+            ckpt_dir,
+            tokenizer_path=tokenizer_path,
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            model_parallel_size=model_parallel_size,
+        )
+    else:
         logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
-        model = create_multimodal_model(generator.args, mesh_device)
-        generator.model = model
+        mesh_device.enable_program_cache()
+        mesh_device.enable_async(True)
+        model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        formatter = ChatFormat(tokenizer)
+        generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
 
     # image understanding
     dialogs = []
@@ -85,26 +84,21 @@ def test_llama_multimodal_demo_chat(
             )
         ],
     ]
     # text only
-    dialogs += [
-        [UserMessage(content="what is the recipe of mayonnaise in two sentences?")],
-    ]
 
-    print(f"Running text completion on {target}")
-    for _ in range(warmup_iters + 1):
-        for dialog in dialogs:
-            result = generator.chat_completion(
-                dialog,
-                max_gen_len=max_gen_len,
-                temperature=temperature,
-                top_p=top_p,
-            )
+    for dialog in dialogs:
+        result = generator.chat_completion(
+            dialog,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+        )
 
-            for msg in dialog:
-                print(f"{msg.role.capitalize()}: {msg.content}\n")
+        for msg in dialog:
+            print(f"{msg.role.capitalize()}: {msg.content}\n")
 
-            out_message = result.generation
-            print(f"> {out_message.role.capitalize()}: {out_message.content}")
-            for t in out_message.tool_calls:
-                print(f"  Tool call: {t.tool_name} ({t.arguments})")
-            print("\n==================================\n")
+        out_message = result.generation
+        print(f"> {out_message.role.capitalize()}: {out_message.content}")
+        for t in out_message.tool_calls:
+            print(f"  Tool call: {t.tool_name} ({t.arguments})")
+        print("\n==================================\n")
64 changes: 25 additions & 39 deletions models/demos/llama3/demo/multimodal_demo_text.py
@@ -8,36 +8,22 @@
 from PIL import Image as PIL_Image
 from termcolor import cprint
 
+import llama_models.llama3.reference_impl.generation as llama_reference_generation
+import pytest
+import os
+import ttnn
+
-import llama_models.llama3.reference_impl.generation as llama_reference_generation
 from llama_models.llama3.api.datatypes import ImageMedia
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.api.chat_format import ChatFormat
 
 
 from pkg_resources import resource_filename
 
 IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))
 
-import torch
-import pytest
-import os
-import ttnn
-
-
-def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16):
-    from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer
-    from models.demos.llama3.tt.model_config import TtModelArgs
-
-    tt_model_args = TtModelArgs(mesh_device)
-    checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True)
-    model = CrossAttentionTransformer(
-        model_args,
-        mesh_device,
-        checkpoint,
-        weight_cache_path=tt_model_args.weight_cache_path(dtype),
-        dtype=dtype,
-        configuration=tt_model_args,
-    )
-    model.setup_cache(model_args.max_batch_size, torch.float32)
-    return model
+from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
+from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
 
 
 @pytest.mark.parametrize(
@@ -64,28 +50,30 @@ def test_llama_multimodal_demo_text(
     temperature: float = 0.5,
     top_p: float = 0.9,
     max_seq_len: int = 512,
-    max_batch_size: int = 4,
+    max_batch_size: int = 1,
     max_gen_len: Optional[int] = 200,
     model_parallel_size: Optional[int] = None,
 ):
-    mesh_device.enable_program_cache()
-    mesh_device.enable_async(True)
     ckpt_dir = os.environ["LLAMA_DIR"]
     tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
 
-    logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'")
-    generator = llama_reference_generation.Llama.build(
-        ckpt_dir,
-        tokenizer_path=tokenizer_path,
-        max_seq_len=max_seq_len,
-        max_batch_size=max_batch_size,
-        model_parallel_size=model_parallel_size,
-    )
 
-    if target == "tt":
+    if target == "cpu":
+        generator = llama_reference_generation.Llama.build(
+            ckpt_dir,
+            tokenizer_path=tokenizer_path,
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            model_parallel_size=model_parallel_size,
+        )
+    else:
         logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
-        model = create_multimodal_model(generator.args, mesh_device)
-        generator.model = model
+        mesh_device.enable_program_cache()
+        mesh_device.enable_async(True)
+        model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        formatter = ChatFormat(tokenizer)
+        generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
 
     with open(IMG_PATH / "dog.jpg", "rb") as f:
         img = PIL_Image.open(f).convert("RGB")
@@ -100,8 +88,6 @@ def test_llama_multimodal_demo_text(
         clutter = PIL_Image.open(f).convert("RGB")
 
     interleaved_contents = [
-        # text only
-        "The color of the sky is blue but sometimes it can also be",
         # image understanding
         [ImageMedia(image=img), "If I had to write a haiku for this one"],
         [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"],
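
create_multimodal_model now lives in simple_vision_demo.py and is imported by both demos; only its call signature is visible in this diff. A hypothetical reconstruction, inferred from the deleted local helper above, is sketched below; the TtModelArgs keyword arguments and the retained setup_cache call are assumptions, not confirmed by the diff.

import torch
import ttnn

# Hypothetical reconstruction of create_multimodal_model as relocated to
# models/demos/llama3/demo/simple_vision_demo.py. Only the call signature
# (mesh_device, max_batch_size, max_seq_len) is confirmed by this diff; the
# body below is inferred from the deleted local helper and may not match.
def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16):
    from models.demos.llama3.tt.model_config import TtModelArgs
    from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer

    # Assumption: TtModelArgs now takes the batch/sequence limits directly.
    tt_model_args = TtModelArgs(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
    checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True)
    model = CrossAttentionTransformer(
        mesh_device,
        checkpoint,
        weight_cache_path=tt_model_args.weight_cache_path(dtype),
        dtype=dtype,
        configuration=tt_model_args,
    )
    model.setup_cache(tt_model_args.max_batch_size, torch.float32)
    return tt_model_args, model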