From c8113d2e2cafec2c51ac5e288dc6b5f6bec4c3a3 Mon Sep 17 00:00:00 2001 From: Bohan Cheng <47214785+cbh778899@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:57:48 +1100 Subject: [PATCH] bug fix for file upload & session switch (#67) * fix aws file upload bugs Signed-off-by: cbh778899 * edit styles of code section Signed-off-by: cbh778899 * add padding for user input element Signed-off-by: cbh778899 * add MIN_TOKENS Signed-off-by: cbh778899 * if max_tokens set to 0, there will be no limitations Signed-off-by: cbh778899 * remove output to console Signed-off-by: cbh778899 * add special value to bypass checkValue Signed-off-by: cbh778899 * change default max_tokens to 1024 as we have gpu inference now Signed-off-by: cbh778899 * fix engine switching bugs, engines are not shared between sessions any more Signed-off-by: cbh778899 --------- Signed-off-by: cbh778899 --- preloader/node-llama-cpp-preloader.js | 2 +- src/components/chat/UserMessage.jsx | 4 ++-- src/components/chat/index.jsx | 24 ++++++++++--------- src/components/settings/ModelSettings.jsx | 10 +++++--- .../components/ScrollBarComponent.jsx | 4 ++-- src/styles/chat.css | 13 ++++++++-- src/utils/general_settings.js | 2 +- src/utils/types.js | 4 +++- src/utils/workers/aws-worker.js | 17 ++++++++----- src/utils/workers/wllama-worker.js | 3 ++- 10 files changed, 53 insertions(+), 30 deletions(-) diff --git a/preloader/node-llama-cpp-preloader.js b/preloader/node-llama-cpp-preloader.js index b897797..0b8599b 100644 --- a/preloader/node-llama-cpp-preloader.js +++ b/preloader/node-llama-cpp-preloader.js @@ -106,10 +106,10 @@ async function chatCompletions(latest_message, cb=null) { const options = { signal: stop_signal.signal, stopOnAbortSignal: true, - maxTokens: max_tokens, topP: top_p, temperature } + if(max_tokens) options.maxTokens = max_tokens let resp_text = '' if(cb) options.onTextChunk = chunk => { resp_text += chunk; diff --git a/src/components/chat/UserMessage.jsx b/src/components/chat/UserMessage.jsx index dce3278..97a8142 100644 --- a/src/components/chat/UserMessage.jsx +++ b/src/components/chat/UserMessage.jsx @@ -16,7 +16,7 @@ export default function UserMessage({ uid, enable_send, file_available, abort_co event.preventDefault(); send(message, files); setMessage(''); - setFiles(''); + setFiles([]); } // update when uid changed, means we entered a new conversation @@ -41,7 +41,7 @@ export default function UserMessage({ uid, enable_send, file_available, abort_co e.name).join('; ')}` : "Select file to append"} - onChange={evt=>setFiles(evt.target.files.length ? evt.target.files[0] : null)} /> + onChange={evt=>setFiles([...evt.target.files])} /> } setMessage(evt.target.value)}/> diff --git a/src/components/chat/index.jsx b/src/components/chat/index.jsx index 956fb3a..dd969fc 100644 --- a/src/components/chat/index.jsx +++ b/src/components/chat/index.jsx @@ -1,11 +1,11 @@ -import { useEffect, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import Tickets from "./Tickets"; // import Conversation from "./Conversation"; import useIDB from "../../utils/idb"; import DeleteConfirm from "./DeleteConfirm"; import ChatPage from "./ChatPage"; -import { useRef } from "react"; import { getCompletionFunctions } from "../../utils/workers"; +import { getPlatformSettings } from "../../utils/general_settings"; export default function Chat() { @@ -20,8 +20,8 @@ export default function Chat() { const [pending_message, setPendingMessage] = useState(null); const idb = useIDB(); - // const settings = useRef(getCompletionFunctions()); - const settings = useRef(getCompletionFunctions()); + const platform = useRef(getPlatformSettings().enabled_platform); + const [session_setting, setSessionSetting] = useState({}); async function sendMessage(message, files) { // save user messages @@ -59,11 +59,11 @@ export default function Chat() { // start inference const send_message = ( - settings.current.formator ? - await settings.current.formator(history_save, files) : history_save + session_setting.formator ? + await session_setting.formator(history_save, files) : history_save ) setPendingMessage('') - await settings.current.completions(send_message, cb) + await session_setting.completions(send_message, cb) } function updateChatClient(client) { @@ -127,27 +127,29 @@ export default function Chat() { message_history = messages; setChatHistory(messages) }).finally(()=>{ - const client = settings.current.initClient(chat.client || null, message_history) + const ss = getCompletionFunctions(chat.platform); + const client = ss.initClient(chat.client || null, message_history) if(!chat.client) { updateChatClient(client) } + setSessionSetting(ss); }) } // eslint-disable-next-line }, [chat]) return ( - settings.current ? + platform.current ?
= min; + return (v <= max && v >= min) || v === special; } function setValue(value, is_scroll = false) { diff --git a/src/styles/chat.css b/src/styles/chat.css index 2465dc3..1450f11 100644 --- a/src/styles/chat.css +++ b/src/styles/chat.css @@ -248,12 +248,21 @@ padding: 10px; overflow: auto; } -.chat > .conversation-main > .bubbles > .bubble :not(pre) code { + +.chat > .conversation-main > .bubbles > .bubble code { padding: 0px 5px; background-color: rgb(227, 227, 227); border-radius: 5px; } +.chat > .conversation-main > .bubbles > .bubble.user code { + background-color: rgb(22, 113, 203); +} + +.chat > .conversation-main > .bubbles > .bubble pre code { + background-color: unset; +} + @keyframes dotAnimation { 0% { color: rgb(90, 90, 90); } 50% { color: rgb(150, 150, 150); } @@ -290,7 +299,7 @@ input[type="text"] { height: 100%; position: relative; border: none; - padding: 0px var(--elem-size) 0px 10px; + padding: 0px calc(var(--elem-size) + 5px) 0px 10px; } .chat > .conversation-main > .send-message-form > .input-container > diff --git a/src/utils/general_settings.js b/src/utils/general_settings.js index f5104b5..e116dc0 100644 --- a/src/utils/general_settings.js +++ b/src/utils/general_settings.js @@ -34,7 +34,7 @@ const MODEL_SETTINGS_KEY = 'general-model-settings' * @property {Number} temperature */ const DEFAULT_MODEL_SETTINGS = { - max_tokens: 128, + max_tokens: 1024, top_p: 0.9, temperature: 0.7 } diff --git a/src/utils/types.js b/src/utils/types.js index a607171..823ab1a 100644 --- a/src/utils/types.js +++ b/src/utils/types.js @@ -6,4 +6,6 @@ export const LOAD_FINISHED = 1; export const LOAD_SET_SETTINGS = 2; export const LOAD_SKIP_SETTINGS = 3; -export const DEFAULT_LLAMA_CPP_MODEL_URL = "https://huggingface.co/aisuko/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi3-mini-4k-instruct-Q4.gguf" \ No newline at end of file +export const DEFAULT_LLAMA_CPP_MODEL_URL = "https://huggingface.co/aisuko/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi3-mini-4k-instruct-Q4.gguf" + +export const MIN_TOKENS = 32; \ No newline at end of file diff --git a/src/utils/workers/aws-worker.js b/src/utils/workers/aws-worker.js index b598113..6f6e85e 100644 --- a/src/utils/workers/aws-worker.js +++ b/src/utils/workers/aws-worker.js @@ -125,15 +125,19 @@ export async function chatCompletions(messages, cb = null) { } }) - const { max_tokens:maxTokens, top_p:topP, temperature } = getModelSettings(); + const { max_tokens, top_p:topP, temperature } = getModelSettings(); const input = { modelId: aws_model_id, messages: normal_messages, inferenceConfig: { - maxTokens, temperature, topP + temperature, topP } } + if(max_tokens) { + input.inferenceConfig.maxTokens = max_tokens + } + if(system.length) input.system = system; let response_text = '', usage = {} @@ -195,9 +199,9 @@ export async function formator(messages, files = []) { if(files.length) { for(const file of files) { const file_info = file.name.split('.') - const extension = file_info.pop(); + const extension = file_info.pop().toLowerCase(); const filename = file_info.join('_'); - const bytes = await file.arrayBuffer() + const bytes = new Uint8Array(await file.arrayBuffer()) if(/^image\/.+/.test(file.type)) { common_messages[common_messages.length - 1].content.push( @@ -209,11 +213,12 @@ export async function formator(messages, files = []) { } ) } else { + const is_valid_format = /^(docx|csv|html|txt|pdf|md|doc|xlsx|xls)$/.test(extension) common_messages[common_messages.length - 1].content.push( { document: { - name: filename, - format: extension, + name: filename + (is_valid_format ? '' : `_${extension}`), + format: is_valid_format ? extension : 'txt' , source: { bytes } } } diff --git a/src/utils/workers/wllama-worker.js b/src/utils/workers/wllama-worker.js index 1440835..5f945e3 100644 --- a/src/utils/workers/wllama-worker.js +++ b/src/utils/workers/wllama-worker.js @@ -35,10 +35,11 @@ export function loadModelSamplingSettings() { n_threads: wllama_threads, n_batch: wllama_batch_size, n_ctx: wllama_context_length, - nPredict: max_tokens, temp: temperature, top_p } + + if(max_tokens) model_sampling_settings.nPredict = max_tokens; } loadModelSamplingSettings();