Skip to content

Commit

Permalink
deepseek-vl2 experiments
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Feb 15, 2025
1 parent c2ec624 commit 8906f96
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ build/
dist/

# dynamically generated
/repositories/ip-instruct/
/repositories/deepseek-vl2/

# all dynamic stuff
/extensions/**/*
Expand Down
4 changes: 2 additions & 2 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
url = https://github.com/ArtVentureX/sd-webui-agent-scheduler
ignore = dirty
[submodule "extensions-builtin/sdnext-modernui"]
path = extensions-builtin/sdnext-modernui
url = https://github.com/BinaryQuantumSoul/sdnext-modernui
path = extensions-builtin/sdnext-modernui
url = https://github.com/BinaryQuantumSoul/sdnext-modernui
92 changes: 92 additions & 0 deletions modules/interrogate/deepseek.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# source: <https://huggingface.co/deepseek-ai/deepseek-vl2-tiny>
# implementation: <https://github.com/deepseek-ai/DeepSeek-VL2/tree/main/deepseek_vl2/serve>
"""
- run `git clone https://github.com/deepseek-ai/DeepSeek-VL2 repositories/deepseek-vl2 --depth 1`
- remove hardcoded `python==3.9` requirement due to obsolete attrdict package dependency
- patch transformers due to internal changes as deepseek requires obsolete `transformers==4.38.2`
- deepseek requires `xformers`
- broken flash_attention
"""

import os
import sys
import importlib
from transformers import AutoModelForCausalLM
from modules import shared, devices, paths


# model_path = "deepseek-ai/deepseek-vl2-small"
vl_gpt = None
vl_chat_processor = None


class fake_attrdict(object):
class AttrDict(dict): # dot notation access to dictionary attributes
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__

# def fake_is_flash_attn_2_available():
# return False


def predict(question, image, repo):
global vl_gpt, vl_chat_processor # pylint: disable=global-statement
if not shared.cmd_opts.experimental:
shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" is experimental-only')
return ''
folder = os.path.join(paths.script_path, 'repositories', 'deepseek-vl2')
if not os.path.exists(folder):
shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" deepseek-vl2 repo not found')
return ''
if vl_gpt is None:
sys.modules['attrdict'] = fake_attrdict
from transformers.models.llama import modeling_llama
modeling_llama.LlamaFlashAttention2 = modeling_llama.LlamaAttention
_deekseek_vl = importlib.import_module('repositories.deepseek-vl2.deepseek_vl2')
deekseek_vl_models = importlib.import_module('repositories.deepseek-vl2.deepseek_vl2.models')
vl_chat_processor = deekseek_vl_models.DeepseekVLV2Processor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
vl_gpt = AutoModelForCausalLM.from_pretrained(
repo,
trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir,
)
vl_gpt = vl_gpt.to(device=devices.device, dtype=devices.dtype).eval()

if len(question) < 2:
question = "Describe the image."
question = question.replace('<', '').replace('>', '')
conversation = [
{
"role": "<|User|>",
"content": f"<image>\n<|ref|>{question}<|/ref|>.",
# "images": [image],
},
{"role": "<|Assistant|>", "content": ""},
]

prepare_inputs = vl_chat_processor(
conversations=conversation,
images=[image],
force_batchify=True,
system_prompt=""
).to(device=devices.device, dtype=devices.dtype)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
inputs_embeds = inputs_embeds.to(device=devices.device, dtype=devices.dtype)
vl_gpt = vl_gpt.to(devices.device)
with devices.inference_context():
outputs = vl_gpt.language.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
max_new_tokens=shared.opts.interrogate_vlm_max_length,
do_sample=False,
use_cache=True
)
vl_gpt = vl_gpt.to(devices.cpu)
answer = vl_chat_processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print('inputs', prepare_inputs['sft_format'][0])
print('answer', answer)
return answer
10 changes: 5 additions & 5 deletions modules/interrogate/vqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
from PIL import Image
from modules import shared, devices, errors

# TODO vlm: add additional models
# https://huggingface.co/nvidia/Eagle2-1B not compatible with latest transformers
# https://huggingface.co/deepseek-ai/deepseek-vl2-tiny requires custom code


processor = None
model = None
loaded: str = None
Expand All @@ -39,6 +34,8 @@
"ViLT Base": "dandelin/vilt-b32-finetuned-vqa", # 0.5GB
"JoyCaption": "fancyfeast/llama-joycaption-alpha-two-hf-llava", # 0.7GB
"JoyTag": "fancyfeast/joytag", # 17.4GB
# "DeepSeek VL2 Tiny": "deepseek-ai/deepseek-vl2-tiny", # broken
# "nVidia Eagle 2 1B": "nvidia/Eagle2-1B", # not compatible with latest transformers
}
vlm_prompts = [
'<CAPTION>',
Expand Down Expand Up @@ -352,6 +349,9 @@ def interrogate(question, image, model_name):
elif 'joycaption' in vqa_model.lower():
from modules.interrogate import joycaption
answer = joycaption.predict(question, image)
elif 'deepseek' in vqa_model.lower():
from modules.interrogate import deepseek
answer = deepseek.predict(question, image, vqa_model)
else:
answer = 'unknown model'
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion modules/prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
\\ | # Removes '\'
\( | # Start '('
\[ | # Start '['
:([+-]?[.\d]+)\) | # Weight ':', followed by an optional sign and a number (integer or decimal), and then ')'
:([+-]?[.\d]+)\) | # Weight ':', followed by an optional sign and a number, and then ')'
\) | # End ')'
\] | # End ']'
[^\\()\[\]:]+ | # Content matches any character except '\', '(', ')', '[', ']', ':'
Expand Down
2 changes: 1 addition & 1 deletion modules/sd_hijack_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
shared.xformers_available = True
except Exception:
pass
else:
elif not shared.cmd_opts.experimental:
if sys.modules.get("xformers", None) is not None:
shared.log.debug('Unloading xFormers')
sys.modules["xformers"] = None
Expand Down

0 comments on commit 8906f96

Please sign in to comment.