batch prompt (批量提示: batch prompts)
shadowcz007 committed Aug 1, 2024
1 parent 4cd6a07 commit be6f47a
Showing 8 changed files with 1,191 additions and 858 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -115,7 +115,7 @@ https://github.com/shadowcz007/comfyui-mixlab-nodes/assets/12645064/e7e77f90-e43

[workflow-5](./workflow/5-gpt-workflow.json)

Latest: the ChatGPT node supports Local LLM (llama.cpp); Phi3 and llama3 can both run directly in a single node.
<!-- Latest: the ChatGPT node supports Local LLM (llama.cpp); Phi3 and llama3 can both run directly in a single node.
Download the model and move it to: `models/llamafile/`
@@ -143,7 +143,7 @@ pip install 'llama-cpp-python[server]'
```
pip install llama-cpp-python \
--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/metal
```
``` -->
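
For anyone still following the local llama.cpp setup above (now commented out in the README), a minimal sanity check could look like the sketch below; the GGUF file name under `models/llamafile/` is a placeholder, not a file shipped with this repo:

```
from llama_cpp import Llama

# Load a local GGUF model (placeholder file name) and run one short chat turn.
llm = Llama(model_path="models/llamafile/Phi-3-mini-4k-instruct-q4.gguf", n_ctx=2048)
out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_tokens=32,
)
print(out["choices"][0]["message"]["content"])
```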

## Prompt

221 changes: 116 additions & 105 deletions __init__.py
@@ -26,12 +26,12 @@
llama_model=""
llama_chat_format=""

try:
from .nodes.ChatGPT import get_llama_models,get_llama_model_path,llama_cpp_client
llama_cpp_client("")
# try:
# from .nodes.ChatGPT import get_llama_models,get_llama_model_path,llama_cpp_client
# llama_cpp_client("")

except:
print("##nodes.ChatGPT ImportError")
# except:
# print("##nodes.ChatGPT ImportError")


from .nodes.RembgNode import get_rembg_models,U2NET_HOME,run_briarmbg,run_rembg
@@ -679,11 +679,11 @@ async def get_checkpoints(request):
except Exception as e:
print('/mixlab/folder_paths',False,e)

try:
if data['type']=='llamafile':
names=get_llama_models()
except:
print("llamafile none")
# try:
# if data['type']=='llamafile':
# names=get_llama_models()
# except:
# print("llamafile none")

try:
if data['type']=='rembg':
@@ -860,117 +860,128 @@ async def mixlab_post_prompt(request):
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)


# AR page
# @routes.get('/mixlab/AR')
async def handle_ar_page(request):
html_file = os.path.join(current_path, "web/ar.html")
if os.path.exists(html_file):
with open(html_file, 'r', encoding='utf-8', errors='ignore') as f:
html_data = f.read()
return web.Response(text=html_data, content_type='text/html')
else:
return web.Response(text="HTML file not found", status=404)

async def start_local_llm(data):
global llama_port,llama_model,llama_chat_format
if llama_port and llama_model and llama_chat_format:
return {"port":llama_port,"model":llama_model,"chat_format":llama_chat_format}

import threading
import uvicorn
from llama_cpp.server.app import create_app
from llama_cpp.server.settings import (
Settings,
ServerSettings,
ModelSettings,
ConfigFileSettings,
)
# async def start_local_llm(data):
# global llama_port,llama_model,llama_chat_format
# if llama_port and llama_model and llama_chat_format:
# return {"port":llama_port,"model":llama_model,"chat_format":llama_chat_format}

# import threading
# import uvicorn
# from llama_cpp.server.app import create_app
# from llama_cpp.server.settings import (
# Settings,
# ServerSettings,
# ModelSettings,
# ConfigFileSettings,
# )

if not "model" in data and "model_path" in data:
data['model']= os.path.basename(data["model_path"])
model=data["model_path"]
# if not "model" in data and "model_path" in data:
# data['model']= os.path.basename(data["model_path"])
# model=data["model_path"]

elif "model" in data:
model=get_llama_model_path(data['model'])
# elif "model" in data:
# model=get_llama_model_path(data['model'])

n_gpu_layers=-1
# n_gpu_layers=-1

if "n_gpu_layers" in data:
n_gpu_layers=data['n_gpu_layers']
# if "n_gpu_layers" in data:
# n_gpu_layers=data['n_gpu_layers']


chat_format="chatml"
# chat_format="chatml"

model_alias=os.path.basename(model)
# model_alias=os.path.basename(model)

# multimodal
clip_model_path=None
# # multimodal
# clip_model_path=None

prefix = "llava-phi-3-mini"
file_name = prefix+"-mmproj-"
if model_alias.startswith(prefix):
for file in os.listdir(os.path.dirname(model)):
if file.startswith(file_name):
clip_model_path=os.path.join(os.path.dirname(model),file)
chat_format='llava-1-5'
# print('#clip_model_path',chat_format,clip_model_path,model)

address="127.0.0.1"
port=9090
success = False
for i in range(11): # try up to 11 times
if await check_port_available(address, port + i):
port = port + i
success = True
break

if success == False:
return {"port":None,"model":""}
# prefix = "llava-phi-3-mini"
# file_name = prefix+"-mmproj-"
# if model_alias.startswith(prefix):
# for file in os.listdir(os.path.dirname(model)):
# if file.startswith(file_name):
# clip_model_path=os.path.join(os.path.dirname(model),file)
# chat_format='llava-1-5'
# # print('#clip_model_path',chat_format,clip_model_path,model)

# address="127.0.0.1"
# port=9090
# success = False
# for i in range(11): # try up to 11 times
# if await check_port_available(address, port + i):
# port = port + i
# success = True
# break

# if success == False:
# return {"port":None,"model":""}


server_settings=ServerSettings(host=address,port=port)

name, ext = os.path.splitext(os.path.basename(model))
if name:
# print('#model',name)
app = create_app(
server_settings=server_settings,
model_settings=[
ModelSettings(
model=model,
model_alias=name,
n_gpu_layers=n_gpu_layers,
n_ctx=4098,
chat_format=chat_format,
embedding=False,
clip_model_path=clip_model_path
)])

def run_uvicorn():
uvicorn.run(
app,
host=os.getenv("HOST", server_settings.host),
port=int(os.getenv("PORT", server_settings.port)),
ssl_keyfile=server_settings.ssl_keyfile,
ssl_certfile=server_settings.ssl_certfile,
)

# create a background thread
thread = threading.Thread(target=run_uvicorn)

# start the background thread
thread.start()

llama_port=port
llama_model=data['model']
llama_chat_format=chat_format

return {"port":llama_port,"model":llama_model,"chat_format":llama_chat_format}
# server_settings=ServerSettings(host=address,port=port)

# name, ext = os.path.splitext(os.path.basename(model))
# if name:
# # print('#model',name)
# app = create_app(
# server_settings=server_settings,
# model_settings=[
# ModelSettings(
# model=model,
# model_alias=name,
# n_gpu_layers=n_gpu_layers,
# n_ctx=4098,
# chat_format=chat_format,
# embedding=False,
# clip_model_path=clip_model_path
# )])

# def run_uvicorn():
# uvicorn.run(
# app,
# host=os.getenv("HOST", server_settings.host),
# port=int(os.getenv("PORT", server_settings.port)),
# ssl_keyfile=server_settings.ssl_keyfile,
# ssl_certfile=server_settings.ssl_certfile,
# )

# # create a background thread
# thread = threading.Thread(target=run_uvicorn)

# # start the background thread
# thread.start()

# llama_port=port
# llama_model=data['model']
# llama_chat_format=chat_format

# return {"port":llama_port,"model":llama_model,"chat_format":llama_chat_format}

# start the llama service
@routes.post('/mixlab/start_llama')
async def my_hander_method(request):
data =await request.json()
# print(data)
if llama_port and llama_model and llama_chat_format:
return web.json_response({"port":llama_port,"model":llama_model,"chat_format":llama_chat_format} )
try:
result=await start_local_llm(data)
except:
result= {"port":None,"model":"","llama_cpp_error":True}
print('start_local_llm error')
# @routes.post('/mixlab/start_llama')
# async def my_hander_method(request):
# data =await request.json()
# # print(data)
# if llama_port and llama_model and llama_chat_format:
# return web.json_response({"port":llama_port,"model":llama_model,"chat_format":llama_chat_format} )
# try:
# result=await start_local_llm(data)
# except:
# result= {"port":None,"model":"","llama_cpp_error":True}
# print('start_local_llm error')

return web.json_response(result)
# return web.json_response(result)
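
# Usage sketch for the /mixlab/start_llama endpoint above (currently commented out).
# Assumptions not taken from this repo: ComfyUI reachable at 127.0.0.1:8188 and a GGUF
# file at the given model_path; both values are placeholders.
import json
import urllib.request

payload = {"model_path": "models/llamafile/your-model.gguf", "n_gpu_layers": -1}
req = urllib.request.Request(
    "http://127.0.0.1:8188/mixlab/start_llama",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    # Expected response shape: {"port": ..., "model": ..., "chat_format": ...}
    print(json.loads(resp.read()))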

# restart the service
@routes.post('/mixlab/re_start')
(Diffs for the remaining 6 changed files are not shown in this view.)
