fix some stuff for llama2 compat
Signed-off-by: Matteo Collina <[email protected]>
mcollina committed May 9, 2024
1 parent 0472cb2 commit f6ba320
Showing 4 changed files with 57 additions and 13 deletions.
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
@@ -43,6 +43,21 @@ Steps for downloading and setting up AI Warp for local development.
 ```bash
 node ../dist/cli/start.js
 ```
+### Testing a model with OpenAI
+
+To test a remote model with OpenAI, you can use the following
+configuration, which points at the model we used for testing:
+
+```json
+"aiProvider": {
+  "openai": {
+    "model": "gpt-3.5-turbo",
+    "apiKey": "{PLT_OPENAI_API_KEY}"
+  }
+}
+```
+
+Make sure to add your OpenAI API key as `PLT_OPENAI_API_KEY` in your `.env` file.
 
 ### Testing a local model with llama2
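The `{PLT_OPENAI_API_KEY}` placeholder in the configuration above is resolved from the environment when the service loads its config, so the key has to be present before AI Warp starts. A minimal pre-flight check, hypothetical and not part of this commit, assuming only that the key is exposed as an ordinary environment variable via `.env`:

```typescript
// Hypothetical pre-flight check: fail fast if the key referenced by
// "{PLT_OPENAI_API_KEY}" in the aiProvider config has not been provided.
const apiKey = process.env.PLT_OPENAI_API_KEY
if (apiKey === undefined || apiKey === '') {
  throw new Error('PLT_OPENAI_API_KEY is not set; add it to your .env file')
}
```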
43 changes: 34 additions & 9 deletions ai-providers/llama2.ts
@@ -1,4 +1,5 @@
 import { ReadableByteStreamController, ReadableStream, UnderlyingByteSource } from 'stream/web'
+import { FastifyLoggerInstance } from 'fastify'
 import {
   LLamaChatPromptOptions,
   LlamaChatSession,
@@ -60,13 +61,16 @@ class Llama2ByteSource implements UnderlyingByteSource {
   backloggedChunks: ChunkQueue = new ChunkQueue()
   finished: boolean = false
   controller?: ReadableByteStreamController
+  abortController: AbortController
 
-  constructor (session: LlamaChatSession, prompt: string, chunkCallback?: StreamChunkCallback) {
+  constructor (session: LlamaChatSession, prompt: string, logger: FastifyLoggerInstance, chunkCallback?: StreamChunkCallback) {
     this.session = session
     this.chunkCallback = chunkCallback
+    this.abortController = new AbortController()
 
     session.prompt(prompt, {
-      onToken: this.onToken
+      onToken: this.onToken,
+      signal: this.abortController.signal
     }).then(() => {
       this.finished = true
       // Don't close the stream if we still have chunks to send
@@ -75,22 +79,36 @@ class Llama2ByteSource implements UnderlyingByteSource {
       }
     }).catch((err: any) => {
       this.finished = true
-      if (this.controller !== undefined) {
-        this.controller.close()
+      logger.info({ err })
+      if (!this.abortController.signal.aborted && this.controller !== undefined) {
+        try {
+          this.controller.close()
+        } catch (err) {
+          logger.info({ err })
+        }
       }
       throw err
     })
   }
 
+  cancel (): void {
+    this.abortController.abort()
+  }
+
   onToken: LLamaChatPromptOptions['onToken'] = async (chunk) => {
     if (this.controller === undefined) {
       // Stream hasn't started yet, added it to the backlog queue
       this.backloggedChunks.push(chunk)
       return
     }
 
-    await this.clearBacklog()
-    await this.enqueueChunk(chunk)
+    try {
+      await this.clearBacklog()
+      await this.enqueueChunk(chunk)
+      // Ignore all errors, we can't do anything about them
+      // TODO: Log these errors
+    } catch (err) {
+      console.error(err)
+    }
   }
 
   private async enqueueChunk (chunk: number[]): Promise<void> {
@@ -103,6 +121,10 @@ class Llama2ByteSource implements UnderlyingByteSource {
       response = await this.chunkCallback(response)
     }
 
+    if (response === '') {
+      response = '\n' // It seems empty chunks are newlines
+    }
+
     const eventData: AiStreamEvent = {
       event: 'content',
       data: {
@@ -139,14 +161,17 @@ class Llama2ByteSource implements UnderlyingByteSource {
 
 interface Llama2ProviderCtorOptions {
   modelPath: string
+  logger: FastifyLoggerInstance
 }
 
 export class Llama2Provider implements AiProvider {
   context: LlamaContext
+  logger: FastifyLoggerInstance
 
-  constructor ({ modelPath }: Llama2ProviderCtorOptions) {
+  constructor ({ modelPath, logger }: Llama2ProviderCtorOptions) {
     const model = new LlamaModel({ modelPath })
     this.context = new LlamaContext({ model })
+    this.logger = logger
   }
 
   async ask (prompt: string): Promise<string> {
@@ -159,6 +184,6 @@ export class Llama2Provider implements AiProvider {
   async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
     const session = new LlamaChatSession({ context: this.context })
 
-    return new ReadableStream(new Llama2ByteSource(session, prompt, chunkCallback))
+    return new ReadableStream(new Llama2ByteSource(session, prompt, this.logger, chunkCallback))
   }
 }
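The new `cancel()` hook is what makes client disconnects safe for the llama2 provider: cancelling the `ReadableStream` returned by `askStream()` reaches `Llama2ByteSource.cancel()`, which aborts the `AbortController` whose signal was handed to `session.prompt()`, so the prompt stops instead of running on in the background. A minimal consumer sketch, hypothetical and not part of the commit, with a placeholder model path and a Fastify instance used only to obtain a logger:

```typescript
import fastify from 'fastify'
import { Llama2Provider } from './ai-providers/llama2.js'

// Hypothetical consumer; the import path and model file are placeholders, and
// app.log is used only because Llama2Provider now requires a Fastify logger.
const app = fastify()
const provider = new Llama2Provider({
  modelPath: './models/llama-2-7b-chat.gguf',
  logger: app.log
})

async function main (): Promise<void> {
  const stream = await provider.askStream('Tell me a long story')
  const reader = stream.getReader()

  // Read a single chunk, then stop listening.
  const { value } = await reader.read()
  console.log(value)

  // cancel() propagates to Llama2ByteSource.cancel(), which calls
  // abortController.abort(); the pending session.prompt() then rejects, and
  // because signal.aborted is true the catch handler only logs the error
  // instead of trying to close a controller the cancellation already tore down.
  await reader.cancel()
}

main().catch((err) => app.log.error(err))
```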
10 changes: 7 additions & 3 deletions plugins/warp.ts
@@ -1,5 +1,6 @@
 // eslint-disable-next-line
 /// <reference path="../index.d.ts" />
+import { FastifyLoggerInstance } from 'fastify'
 import fastifyPlugin from 'fastify-plugin'
 import { OpenAiProvider } from '../ai-providers/open-ai.js'
 import { MistralProvider } from '../ai-providers/mistral.js'
@@ -12,7 +13,7 @@ import { Llama2Provider } from '../ai-providers/llama2.js'
 
 const UnknownAiProviderError = createError('UNKNOWN_AI_PROVIDER', 'Unknown AI Provider')
 
-function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider {
+function build (aiProvider: AiWarpConfig['aiProvider'], logger: FastifyLoggerInstance): AiProvider {
   if ('openai' in aiProvider) {
     return new OpenAiProvider(aiProvider.openai)
   } else if ('mistral' in aiProvider) {
@@ -22,15 +23,18 @@ function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider {
   } else if ('azure' in aiProvider) {
     return new AzureProvider(aiProvider.azure)
   } else if ('llama2' in aiProvider) {
-    return new Llama2Provider(aiProvider.llama2)
+    return new Llama2Provider({
+      ...aiProvider.llama2,
+      logger
+    })
   } else {
     throw new UnknownAiProviderError()
   }
 }
 
 export default fastifyPlugin(async (fastify) => {
   const { config } = fastify.platformatic
-  const provider = build(config.aiProvider)
+  const provider = build(config.aiProvider, fastify.log)
 
   fastify.decorate('ai', {
     warp: async (request, prompt) => {
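Only the llama2 branch of `build()` changes shape, since `Llama2Provider` is now the one provider that needs the Fastify logger injected next to its configuration options. A minimal sketch of the equivalent wiring done by hand, hypothetical and assuming the `aiProvider` config is shaped like `{ llama2: { modelPath } }` with a placeholder path:

```typescript
import fastify from 'fastify'
import { Llama2Provider } from '../ai-providers/llama2.js'

// Hypothetical stand-in for config.aiProvider; import path and model file are placeholders.
const aiProvider = { llama2: { modelPath: './models/llama-2-7b-chat.gguf' } }

const app = fastify({ logger: true })

// Mirrors the llama2 branch of build(): spread the options that came from the
// config and inject the Fastify logger alongside them, so the provider (and
// every Llama2ByteSource it creates) reports errors through fastify.log.
const provider = new Llama2Provider({
  ...aiProvider.llama2,
  logger: app.log
})

app.log.info({ provider: provider.constructor.name }, 'llama2 provider ready')
```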
2 changes: 1 addition & 1 deletion static/scripts/chat.js
@@ -7,7 +7,7 @@ const messagesElement = document.getElementById('messages')
 /**
  * List of completed messages to easily keep track of them instead of making
  * calls to the DOM
- * 
+ *
  * { type: 'prompt' | 'response' | 'error', message?: string }
  */
 const messages = []