feat(Dynamic Models): together AI Dynamic Models
thecodacus committed Dec 2, 2024
1 parent 115dcbb commit 1589d2a
Showing 6 changed files with 113 additions and 25 deletions.
18 changes: 7 additions & 11 deletions app/lib/.server/llm/stream-text.ts
@@ -1,11 +1,8 @@
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-nocheck – TODO: Provider proper types

import { convertToCoreMessages, streamText as _streamText } from 'ai';
import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';

interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
@@ -32,26 +29,25 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid

// Extract provider
const providerMatch = message.content.match(PROVIDER_REGEX);
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER.name;

// Remove model and provider lines from content
const cleanedContent = message.content.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, '').trim();

return { model, provider, content: cleanedContent };
}

export function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
export async function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;

let currentProvider = DEFAULT_PROVIDER.name;
const MODEL_LIST = await getModelList(apiKeys || {});
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, provider, content } = extractPropertiesFromMessage(message);

if (MODEL_LIST.find((m) => m.name === model)) {
currentModel = model;
}

currentProvider = provider;

return { ...message, content };
@@ -65,10 +61,10 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;

return _streamText({
model: getModel(currentProvider, currentModel, env, apiKeys),
model: getModel(currentProvider, currentModel, env, apiKeys) as any,
system: getSystemPrompt(),
maxTokens: dynamicMaxTokens,
messages: convertToCoreMessages(processedMessages),
messages: convertToCoreMessages(processedMessages as any),
...options,
});
}
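
This hunk's net effect is that streamText is now asynchronous: it awaits getModelList so that dynamically discovered models (such as Together's) are considered when resolving the model named in a message, and it accepts the caller's apiKeys for that lookup. A minimal call-site sketch under the new signature (the surrounding handler context and the key value are illustrative, not from this commit):

// streamText must now be awaited, and per-user API keys are threaded
// through so dynamic model discovery can authenticate.
const apiKeys: Record<string, string> = { Together: 'tgp-xxxx' }; // placeholder key
const result = await streamText(messages, env, { toolChoice: 'none' }, apiKeys);
const stream = result.toAIStream();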
14 changes: 5 additions & 9 deletions app/routes/api.chat.ts
@@ -1,6 +1,3 @@
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-nocheck – TODO: Provider proper types

import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
@@ -11,8 +8,8 @@ export async function action(args: ActionFunctionArgs) {
return chatAction(args);
}

function parseCookies(cookieHeader) {
const cookies = {};
function parseCookies(cookieHeader: string) {
const cookies: any = {};

// Split the cookie string by semicolons and spaces
const items = cookieHeader.split(';').map((cookie) => cookie.trim());
@@ -39,14 +36,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
const cookieHeader = request.headers.get('Cookie');

// Parse the cookie's value (returns an object or null if no cookie exists)
const apiKeys = JSON.parse(parseCookies(cookieHeader).apiKeys || '{}');
const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');

const stream = new SwitchableStream();

try {
const options: StreamingOptions = {
toolChoice: 'none',
apiKeys,
onFinish: async ({ text: content, finishReason }) => {
if (finishReason !== 'length') {
return stream.close();
@@ -63,7 +59,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
messages.push({ role: 'assistant', content });
messages.push({ role: 'user', content: CONTINUE_PROMPT });

const result = await streamText(messages, context.cloudflare.env, options);
const result = await streamText(messages, context.cloudflare.env, options, apiKeys);

return stream.switchSource(result.toAIStream());
},
@@ -79,7 +75,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
contentType: 'text/plain; charset=utf-8',
},
});
} catch (error) {
} catch (error: any) {
console.log(error);

if (error.message?.includes('API key')) {
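
To make the cookie round-trip concrete: the client stores user keys as a JSON string in an apiKeys cookie, and the handler decodes and parses it back into a record, as sketched below (the key value is a placeholder):

// Incoming header, URL-encoded on the wire:
// Cookie: apiKeys=%7B%22Together%22%3A%22tgp-xxxx%22%7D
const cookieHeader = 'apiKeys=%7B%22Together%22%3A%22tgp-xxxx%22%7D';
const apiKeys = JSON.parse(parseCookies(cookieHeader).apiKeys || '{}');
// apiKeys => { Together: 'tgp-xxxx' }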
2 changes: 1 addition & 1 deletion app/types/model.ts
@@ -3,7 +3,7 @@ import type { ModelInfo } from '~/utils/types';
export type ProviderInfo = {
staticModels: ModelInfo[];
name: string;
getDynamicModels?: () => Promise<ModelInfo[]>;
getDynamicModels?: (apiKeys?: Record<string, string>) => Promise<ModelInfo[]>;
getApiKeyLink?: string;
labelForGetApiKey?: string;
icon?: string;
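
With apiKeys added to the optional getDynamicModels signature, a provider entry can use per-user keys during model discovery while env-only providers remain unaffected. A sketch of a conforming entry (ExampleProvider and fetchModelsWithKey are hypothetical, not part of this commit):

const exampleProvider: ProviderInfo = {
  name: 'ExampleProvider',
  staticModels: [],
  // apiKeys is optional, so providers that rely only on env vars still conform
  getDynamicModels: async (apiKeys?: Record<string, string>) => {
    const key = apiKeys?.['ExampleProvider'] ?? '';
    return key ? fetchModelsWithKey(key) : [];
  },
};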
83 changes: 80 additions & 3 deletions app/utils/constants.ts
@@ -1,3 +1,5 @@
import Cookies from 'js-cookie';
import { parseCookies } from './parseCookies';
import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';
import type { ProviderInfo } from '~/types/model';

@@ -262,6 +264,7 @@ const PROVIDER_LIST: ProviderInfo[] = [
},
{
name: 'Together',
getDynamicModels: getTogetherModels,
staticModels: [
{
name: 'Qwen/Qwen2.5-Coder-32B-Instruct',
@@ -293,6 +296,61 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat(

export let MODEL_LIST: ModelInfo[] = [...staticModels];


export async function getModelList(apiKeys: Record<string, string>) {
MODEL_LIST = [
...(
await Promise.all(
PROVIDER_LIST.filter(
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
).map((p) => p.getDynamicModels(apiKeys)),
)
).flat(),
...staticModels,
];
return MODEL_LIST;
}

async function getTogetherModels(apiKeys?: Record<string, string>): Promise<ModelInfo[]> {
try {
let baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
let provider = 'Together';

if (!baseUrl) {
return [];
}
let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';

if (apiKeys && apiKeys[provider]) {
apiKey = apiKeys[provider];
}

if (!apiKey) {
return [];
}

const response = await fetch(`${baseUrl}/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
const res = (await response.json()) as any;
let data: any[] = (res || []).filter((model: any) => model.type == 'chat');
return data.map((m: any) => ({
name: m.id,
label: `${m.display_name} - in:$${m.pricing.input.toFixed(2)} out:$${m.pricing.output.toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`,
provider: provider,
maxTokenAllowed: 8000,
}));
} catch (e) {
console.error('Error getting Together models:', e);
return [];
}
}


const getOllamaBaseUrl = () => {
const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';

@@ -339,8 +397,13 @@ async function getOpenAILikeModels(): Promise<ModelInfo[]> {
if (!baseUrl) {
return [];
}
let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';

let apikeys = JSON.parse(Cookies.get('apiKeys') || '{}');
if (apikeys && apikeys['OpenAILike']) {
apiKey = apikeys['OpenAILike'];
}

const apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
const response = await fetch(`${baseUrl}/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
Expand Down Expand Up @@ -396,7 +459,6 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
if (typeof window === 'undefined') {
return [];
}

try {
const baseUrl = import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
const response = await fetch(`${baseUrl}/v1/models`);
@@ -414,12 +476,27 @@
}

async function initializeModelList(): Promise<ModelInfo[]> {
let apiKeys: Record<string, string> = {};
try {
const storedApiKeys = Cookies.get('apiKeys');

if (storedApiKeys) {
const parsedKeys = JSON.parse(storedApiKeys);

if (typeof parsedKeys === 'object' && parsedKeys !== null) {
apiKeys = parsedKeys;
}
}

} catch (error) {
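// Ignore malformed apiKeys cookies and continue with no user keys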

}
MODEL_LIST = [
...(
await Promise.all(
PROVIDER_LIST.filter(
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
).map((p) => p.getDynamicModels()),
).map((p) => p.getDynamicModels(apiKeys)),
)
).flat(),
...staticModels,
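
getTogetherModels assumes that GET `${baseUrl}/models` returns an array of model objects carrying id, display_name, type, context_length, and pricing fields. A hedged sample of one entry and the ModelInfo the mapping above would produce (values illustrative, not from Together's actual catalog):

// Assumed shape of one /models entry:
const sampleEntry = {
  id: 'Qwen/Qwen2.5-Coder-32B-Instruct',
  display_name: 'Qwen 2.5 Coder 32B',
  type: 'chat',
  context_length: 32768,
  pricing: { input: 0.8, output: 0.8 },
};
// ...which the map above turns into:
// {
//   name: 'Qwen/Qwen2.5-Coder-32B-Instruct',
//   label: 'Qwen 2.5 Coder 32B - in:$0.80 out:$0.80 - context 32k',
//   provider: 'Together',
//   maxTokenAllowed: 8000,
// }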
19 changes: 19 additions & 0 deletions app/utils/parseCookies.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
export function parseCookies(cookieHeader: string) {
const cookies: any = {};

// Split the cookie string by semicolons and spaces
const items = cookieHeader.split(';').map((cookie) => cookie.trim());

items.forEach((item) => {
const [name, ...rest] = item.split('=');

if (name && rest) {
// Decode the name and value, and join value parts in case it contains '='
const decodedName = decodeURIComponent(name.trim());
const decodedValue = decodeURIComponent(rest.join('=').trim());
cookies[decodedName] = decodedValue;
}
});

return cookies;
}
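
A quick usage sketch of the new helper; joining the trailing '=' segments keeps values that themselves contain '=' (such as base64 payloads) intact:

const cookies = parseCookies('theme=dark; token=abc=def');
// => { theme: 'dark', token: 'abc=def' }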
2 changes: 1 addition & 1 deletion vite.config.ts
@@ -28,7 +28,7 @@ export default defineConfig((config) => {
chrome129IssuePlugin(),
config.mode === 'production' && optimizeCssModules({ apply: 'build' }),
],
envPrefix:["VITE_","OPENAI_LIKE_API_","OLLAMA_API_BASE_URL","LMSTUDIO_API_BASE_URL"],
envPrefix: ["VITE_", "OPENAI_LIKE_API_", "OLLAMA_API_BASE_URL", "LMSTUDIO_API_BASE_URL","TOGETHER_API_BASE_URL"],
css: {
preprocessorOptions: {
scss: {
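
The envPrefix addition is what makes the new provider reachable at runtime: Vite only exposes variables matching envPrefix on import.meta.env, so without 'TOGETHER_API_BASE_URL' here the baseUrl check in getTogetherModels would always see an empty string and silently return no dynamic models:

// Resolves only because TOGETHER_API_BASE_URL is now in envPrefix;
// otherwise import.meta.env.TOGETHER_API_BASE_URL is undefined.
const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';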
