diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index ed5231b5..8c2af36a 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -5,9 +5,6 @@ on:
   tags:
     - "*"
-  branches:
-    - main
-
 permissions:
   contents: write
   packages: write
@@ -70,7 +67,10 @@ jobs:
           fetch-depth: 0

       - name: login into Github Container Registry
-        run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u $ --password-stdin
+        run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
+
+      - name: login into DockerHub
+        run: echo "${{ secrets.DOCKER_HUB_TOKEN }}" | docker login -u einstack --password-stdin

       - name: login into Github Container Registry
         run: echo "${{ secrets.DOCKER_HUB_TOKEN }}" | docker login -u einstack $ --password-stdin
@@ -85,6 +85,6 @@ jobs:
         working-directory: ./images
         run: VERSION=${{ github.ref_name }} make publish-ghcr-${{ matrix.image }}

-      - name: publish ${{ matrix.image }} image to Github Container Registry
+      - name: publish ${{ matrix.image }} image to DockerHub
         working-directory: ./images
         run: VERSION=${{ github.ref_name }} make publish-dockerhub-${{ matrix.image }}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 48a393dc..05c37577 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,18 @@ The changelog consists of three categories:
 - **Improvements** - bugfixes, performance and other types of improvements to existing functionality
 - **Miscellaneous** - all other updates like build, release, CLI, etc.

+## 0.0.2-rc.2 (Feb 22nd, 2024)
+
+### Features
+
+- ✨ #142: [Lang Chat Router] Ollama Support (@mkrueger12)
+- ✨ #131: [Lang Chat Router] AWS Bedrock Support (@mkrueger12)
+
+### Miscellaneous
+
+- 👷 #155: Fixing the dockerhub authorization step in the release workflow (@roma-glushko)
+- ♻️ #151: Moved specific provider schemas closer to provider's packages (@roma-glushko)
+
 ## 0.0.2-rc.1 (Feb 12th, 2024)

 ### Features
diff --git a/README.md b/README.md
index d35e914a..c3597081 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,24 @@
-# Glide: Cloud-Native LLM Gateway for Seamless LLMOps
+<div align="center">
+    <a href="https://github.com/EinStack/glide"><img src="docs/glide.png" width="400px" alt="Glide GH Header" /></a>
+
+    <h3>Glide: Cloud-Native LLM Gateway for Seamless LLMOps</h3>
+
+    <!-- Badges: CodeCov, Discord, Glide Docs, License, ArtifactHub, FOSSA Status -->
+</div>
-
-[![codecov](https://codecov.io/github/EinStack/glide/graph/badge.svg?token=F7JT39RHX9)](https://codecov.io/github/EinStack/glide)
-[![Discord](https://img.shields.io/discord/1181281407813828710)](https://discord.gg/pt53Ej7rrc)
-[![Documentation](https://img.shields.io/badge/build-view-violet%20?style=flat&logo=books&label=docs&link=https%3A%2F%2Fglide.einstack.ai%2F)](https://glide.einstack.ai/)
-[![LICENSE](https://img.shields.io/github/license/EinStack/glide.svg?style=flat-square&color=%233f90c8)](https://github.com/EinStack/glide/blob/main/LICENSE)
-[![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2FEinStack%2Fglide.svg?type=shield)](https://app.fossa.com/projects/git%2Bgithub.com%2FEinStack%2Fglide?ref=badge_shield)
-
 ---

-Glide is your go-to cloud-native LLM gateway, delivering high-performance LLMOps in a lightweight, all-in-one package.
+**Glide** is your go-to cloud-native LLM gateway, delivering high-performance LLMOps in a lightweight, all-in-one package.

 We take all problems of managing and communicating with external providers out of your applications, so you can dive into tackling your core challenges.

+> [!Important]
+> Give us a star ⭐ to support the project and watch 👀 our repositories so you don't miss any updates. We appreciate your interest 🙏
+
 Glide sits between your application and model providers to seamlessly handle various LLMOps tasks like model failover, caching, key management, etc.

@@ -27,7 +29,7 @@ Take a look at the develop branch.
 Check out our [documentation](https://glide.einstack.ai)!

 > [!Warning]
-> Glide is under active development right now. Give us a star to support the project ✨
+> Glide is under active development right now 🛠️

 ## Features

@@ -38,35 +40,18 @@ Check out our [documentation](https://glide.einstack.ai)!
 - **Production-ready observability** via OpenTelemetry: emits metrics on model health, allows whitebox monitoring (coming soon)
 - Straightforward and simple maintenance and configuration: centralized API key control, management & rotation, etc.

-## Supported Providers
-
 ### Large Language Models

 |                                                     | Provider      | Support Status  |
 |-----------------------------------------------------|---------------|-----------------|
-| | OpenAI | 👍 Supported |
+| | Anthropic | 👍 Supported |
 | | Azure OpenAI | 👍 Supported |
+| | AWS Bedrock (Titan) | 👍 Supported |
 | | Cohere | 👍 Supported |
-| | OctoML | 👍 Supported |
-| | Anthropic | 👍 Supported |
 | | Google Gemini | 🏗️ Coming Soon |
-
-### Routers
-
-Routers are a core functionality of Glide. Think of routers as a group of models with some predefined logic. For example, the resilience router allows a user to define a set of backup models should the initial model fail. Another example, would be to leverage the least-latency router to make latency sensitive LLM calls in the most efficient manner.
-
-Detailed info on routers can be found [here](https://glide.einstack.ai/essentials/routers).
-
-#### Available Routers
-
-| Router | Description |
-|---------------|-----------------|
-| Priority | When the target model fails the request is sent to the secondary model. The entire service instance keeps track of the number of failures for a specific model reducing latency upon model failure |
-| Least Latency | This router selects the model with the lowest average latency over time. If the least latency model becomes unhealthy, it will pick the second the best, etc. |
-| Round Robin | Split traffic equally among specified models. Great for A/B testing. |
-| Weighted Round Robin | Split traffic based on weights. For example, 70% of traffic to Model A and 30% of traffic to Model B. |
-
+| | OctoML | 👍 Supported |
+| | Ollama | 👍 Supported |
+| | OpenAI | 👍 Supported |

 ## Get Started

@@ -107,7 +92,6 @@ See [API Reference](https://glide.einstack.ai/api-reference/introduction) for mo
 ```json
 {
-  "model": "gpt-3.5-turbo", # this is not required but can be used to specify different prompts to different models
   "message": {
     "role": "user",
@@ -196,7 +180,58 @@ docker pull ghcr.io/einstack/glide:latest-redhat

 ### Helm Chart

-Coming Soon
+Add the EinStack repository:
+
+```bash
+helm repo add einstack https://einstack.github.io/helm-charts
+helm repo update
+```
+
+Before installing the Helm chart, you need to create a Kubernetes secret with your API keys, for example:
+
+```bash
+kubectl create secret generic api-keys --from-literal=OPENAI_API_KEY=sk-abcdXYZ
+```
+
+Then, create a custom values.yaml file to override the secret name:
+
+```yaml
+# save as custom.values.yaml, for example
+glide:
+  apiKeySecret: "api-keys"
+```
+
+Finally, install Glide's chart via:
+
+```bash
+helm upgrade glide-gateway einstack/glide --values custom.values.yaml --install
+```
+
+## SDKs
+
+To let you work with Glide's API with ease, we are going to provide SDKs that fit your tech stack:
+
+- Python (coming soon)
+- NodeJS (coming soon)
+- Golang (coming soon)
+- Rust (coming soon)
+
+## Core Concepts
+
+### Routers
+
+Routers are a core functionality of Glide. Think of routers as a group of models with some predefined logic. For example, the resilience router allows a user to define a set of backup models should the initial model fail. Another example would be to leverage the least-latency router to make latency-sensitive LLM calls in the most efficient manner.
+
+Detailed info on routers can be found [here](https://glide.einstack.ai/essentials/routers).
+
+#### Available Routers
+
+| Router | Description |
+|---------------|-----------------|
+| Priority | When the target model fails, the request is sent to the secondary model. The entire service instance keeps track of the number of failures for a specific model, reducing latency upon model failure |
+| Least Latency | This router selects the model with the lowest average latency over time. If the least-latency model becomes unhealthy, it will pick the next best one, etc. |
+| Round Robin | Split traffic equally among specified models. Great for A/B testing. |
+| Weighted Round Robin | Split traffic based on weights. For example, 70% of traffic to Model A and 30% of traffic to Model B. |
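Whichever router strategy is configured, clients talk to it through the unified chat endpoint that this PR's handler annotations describe (`POST /v1/language/{router}/chat`, taking a `schemas.ChatRequest` and returning a `schemas.ChatResponse`). Below is a minimal Go sketch of calling that endpoint; the router ID `myrouter` and the local listen address are illustrative assumptions for the example, not values taken from this diff:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Subset of Glide's unified schemas.ChatRequest / schemas.ChatResponse
// accepted and returned by the /v1/language/{router}/chat handler.
type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatRequest struct {
	Message chatMessage `json:"message"`
}

type chatResponse struct {
	ID       string `json:"id"`
	Provider string `json:"provider"`
	Model    string `json:"model"`
}

func main() {
	payload, err := json.Marshal(chatRequest{
		Message: chatMessage{Role: "user", Content: "What's the biggest animal?"},
	})
	if err != nil {
		panic(err)
	}

	// "myrouter" is a placeholder router ID, and the address assumes a
	// Glide instance running locally; adjust both to match your deployment.
	resp, err := http.Post(
		"http://127.0.0.1:9099/v1/language/myrouter/chat",
		"application/json",
		bytes.NewReader(payload),
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	fmt.Printf("answered by %s/%s (response id %s)\n", out.Provider, out.Model, out.ID)
}
```

Note that the response identifies the provider and model that actually served the request, which is how failover and traffic-splitting behavior can be observed from the client side.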
| ## Community diff --git a/docs/docs.go b/docs/docs.go index 5ede09c8..96ddabc5 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -98,7 +98,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/schemas.UnifiedChatRequest" + "$ref": "#/definitions/schemas.ChatRequest" } } ], @@ -106,7 +106,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/schemas.UnifiedChatResponse" + "$ref": "#/definitions/schemas.ChatResponse" } }, "400": { @@ -256,6 +256,52 @@ const docTemplate = `{ } } }, + "bedrock.Config": { + "type": "object", + "required": [ + "awsRegion", + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "awsRegion": { + "type": "string" + }, + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/bedrock.Params" + }, + "model": { + "type": "string" + } + } + }, + "bedrock.Params": { + "type": "object", + "properties": { + "max_tokens": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + } + } + }, "clients.ClientConfig": { "type": "object", "properties": { @@ -431,6 +477,84 @@ const docTemplate = `{ } } }, + "ollama.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/ollama.Params" + }, + "model": { + "type": "string" + } + } + }, + "ollama.Params": { + "type": "object", + "properties": { + "microstat": { + "type": "integer" + }, + "microstat_eta": { + "type": "number" + }, + "microstat_tau": { + "type": "number" + }, + "num_ctx": { + "type": "integer" + }, + "num_gpu": { + "type": "integer" + }, + "num_gqa": { + "type": "integer" + }, + "num_predict": { + "type": "integer" + }, + "num_thread": { + "type": "integer" + }, + "repeat_last_n": { + "type": "integer" + }, + "seed": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "stream": { + "type": "boolean" + }, + "temperature": { + "type": "number" + }, + "tfs_z": { + "type": "number" + }, + "top_k": { + "type": "integer" + }, + "top_p": { + "type": "number" + } + } + }, "openai.Config": { "type": "object", "required": [ @@ -517,6 +641,9 @@ const docTemplate = `{ "azureopenai": { "$ref": "#/definitions/azureopenai.Config" }, + "bedrock": { + "$ref": "#/definitions/bedrock.Config" + }, "client": { "$ref": "#/definitions/clients.ClientConfig" }, @@ -540,6 +667,9 @@ const docTemplate = `{ "octoml": { "$ref": "#/definitions/octoml.Config" }, + "ollama": { + "$ref": "#/definitions/ollama.Config" + }, "openai": { "description": "Add other providers like", "allOf": [ @@ -627,49 +757,7 @@ const docTemplate = `{ } } }, - "schemas.OverrideChatRequest": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "model_id": { - "type": "string" - } - } - }, - "schemas.ProviderResponse": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "responseId": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "tokenCount": { - "$ref": "#/definitions/schemas.TokenUsage" - } - } - }, - "schemas.TokenUsage": { - "type": "object", - "properties": { - "promptTokens": { - "type": "number" - }, - 
"responseTokens": { - "type": "number" - }, - "totalTokens": { - "type": "number" - } - } - }, - "schemas.UnifiedChatRequest": { + "schemas.ChatRequest": { "type": "object", "properties": { "message": { @@ -686,7 +774,7 @@ const docTemplate = `{ } } }, - "schemas.UnifiedChatResponse": { + "schemas.ChatResponse": { "type": "object", "properties": { "cached": { @@ -714,6 +802,48 @@ const docTemplate = `{ "type": "string" } } + }, + "schemas.OverrideChatRequest": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "model_id": { + "type": "string" + } + } + }, + "schemas.ProviderResponse": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "responseId": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "tokenCount": { + "$ref": "#/definitions/schemas.TokenUsage" + } + } + }, + "schemas.TokenUsage": { + "type": "object", + "properties": { + "promptTokens": { + "type": "number" + }, + "responseTokens": { + "type": "number" + }, + "totalTokens": { + "type": "number" + } + } } }, "externalDocs": { diff --git a/docs/images/aws-icon.png b/docs/images/aws-icon.png new file mode 100644 index 00000000..2ca16dbd Binary files /dev/null and b/docs/images/aws-icon.png differ diff --git a/docs/images/ollama.png b/docs/images/ollama.png new file mode 100644 index 00000000..8cd2cf1e Binary files /dev/null and b/docs/images/ollama.png differ diff --git a/docs/swagger.json b/docs/swagger.json index 8146a113..aee257b6 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -95,7 +95,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/schemas.UnifiedChatRequest" + "$ref": "#/definitions/schemas.ChatRequest" } } ], @@ -103,7 +103,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/schemas.UnifiedChatResponse" + "$ref": "#/definitions/schemas.ChatResponse" } }, "400": { @@ -253,6 +253,52 @@ } } }, + "bedrock.Config": { + "type": "object", + "required": [ + "awsRegion", + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "awsRegion": { + "type": "string" + }, + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/bedrock.Params" + }, + "model": { + "type": "string" + } + } + }, + "bedrock.Params": { + "type": "object", + "properties": { + "max_tokens": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + } + } + }, "clients.ClientConfig": { "type": "object", "properties": { @@ -428,6 +474,84 @@ } } }, + "ollama.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/ollama.Params" + }, + "model": { + "type": "string" + } + } + }, + "ollama.Params": { + "type": "object", + "properties": { + "microstat": { + "type": "integer" + }, + "microstat_eta": { + "type": "number" + }, + "microstat_tau": { + "type": "number" + }, + "num_ctx": { + "type": "integer" + }, + "num_gpu": { + "type": "integer" + }, + "num_gqa": { + "type": "integer" + }, + "num_predict": { + "type": "integer" + }, + "num_thread": { + "type": "integer" + }, + "repeat_last_n": { + "type": "integer" + }, + "seed": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { 
+ "type": "string" + } + }, + "stream": { + "type": "boolean" + }, + "temperature": { + "type": "number" + }, + "tfs_z": { + "type": "number" + }, + "top_k": { + "type": "integer" + }, + "top_p": { + "type": "number" + } + } + }, "openai.Config": { "type": "object", "required": [ @@ -514,6 +638,9 @@ "azureopenai": { "$ref": "#/definitions/azureopenai.Config" }, + "bedrock": { + "$ref": "#/definitions/bedrock.Config" + }, "client": { "$ref": "#/definitions/clients.ClientConfig" }, @@ -537,6 +664,9 @@ "octoml": { "$ref": "#/definitions/octoml.Config" }, + "ollama": { + "$ref": "#/definitions/ollama.Config" + }, "openai": { "description": "Add other providers like", "allOf": [ @@ -624,49 +754,7 @@ } } }, - "schemas.OverrideChatRequest": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "model_id": { - "type": "string" - } - } - }, - "schemas.ProviderResponse": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "responseId": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "tokenCount": { - "$ref": "#/definitions/schemas.TokenUsage" - } - } - }, - "schemas.TokenUsage": { - "type": "object", - "properties": { - "promptTokens": { - "type": "number" - }, - "responseTokens": { - "type": "number" - }, - "totalTokens": { - "type": "number" - } - } - }, - "schemas.UnifiedChatRequest": { + "schemas.ChatRequest": { "type": "object", "properties": { "message": { @@ -683,7 +771,7 @@ } } }, - "schemas.UnifiedChatResponse": { + "schemas.ChatResponse": { "type": "object", "properties": { "cached": { @@ -711,6 +799,48 @@ "type": "string" } } + }, + "schemas.OverrideChatRequest": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "model_id": { + "type": "string" + } + } + }, + "schemas.ProviderResponse": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "responseId": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "tokenCount": { + "$ref": "#/definitions/schemas.TokenUsage" + } + } + }, + "schemas.TokenUsage": { + "type": "object", + "properties": { + "promptTokens": { + "type": "number" + }, + "responseTokens": { + "type": "number" + }, + "totalTokens": { + "type": "number" + } + } } }, "externalDocs": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index c0b25776..d5fb088f 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -90,6 +90,37 @@ definitions: user: type: string type: object + bedrock.Config: + properties: + awsRegion: + type: string + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/bedrock.Params' + model: + type: string + required: + - awsRegion + - baseUrl + - chatEndpoint + - model + type: object + bedrock.Params: + properties: + max_tokens: + type: integer + stop: + items: + type: string + type: array + temperature: + type: number + top_p: + type: number + type: object clients.ClientConfig: properties: timeout: @@ -207,6 +238,58 @@ definitions: top_p: type: number type: object + ollama.Config: + properties: + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/ollama.Params' + model: + type: string + required: + - baseUrl + - chatEndpoint + - model + type: object + ollama.Params: + properties: + microstat: + type: integer + microstat_eta: + type: number + microstat_tau: + type: number + num_ctx: + type: 
integer + num_gpu: + type: integer + num_gqa: + type: integer + num_predict: + type: integer + num_thread: + type: integer + repeat_last_n: + type: integer + seed: + type: integer + stop: + items: + type: string + type: array + stream: + type: boolean + temperature: + type: number + tfs_z: + type: number + top_k: + type: integer + top_p: + type: number + type: object openai.Config: properties: baseUrl: @@ -262,6 +345,8 @@ definitions: $ref: '#/definitions/anthropic.Config' azureopenai: $ref: '#/definitions/azureopenai.Config' + bedrock: + $ref: '#/definitions/bedrock.Config' client: $ref: '#/definitions/clients.ClientConfig' cohere: @@ -278,6 +363,8 @@ definitions: $ref: '#/definitions/latency.Config' octoml: $ref: '#/definitions/octoml.Config' + ollama: + $ref: '#/definitions/ollama.Config' openai: allOf: - $ref: '#/definitions/openai.Config' @@ -342,34 +429,7 @@ definitions: or assistant. type: string type: object - schemas.OverrideChatRequest: - properties: - message: - $ref: '#/definitions/schemas.ChatMessage' - model_id: - type: string - type: object - schemas.ProviderResponse: - properties: - message: - $ref: '#/definitions/schemas.ChatMessage' - responseId: - additionalProperties: - type: string - type: object - tokenCount: - $ref: '#/definitions/schemas.TokenUsage' - type: object - schemas.TokenUsage: - properties: - promptTokens: - type: number - responseTokens: - type: number - totalTokens: - type: number - type: object - schemas.UnifiedChatRequest: + schemas.ChatRequest: properties: message: $ref: '#/definitions/schemas.ChatMessage' @@ -380,7 +440,7 @@ definitions: override: $ref: '#/definitions/schemas.OverrideChatRequest' type: object - schemas.UnifiedChatResponse: + schemas.ChatResponse: properties: cached: type: boolean @@ -399,6 +459,33 @@ definitions: router: type: string type: object + schemas.OverrideChatRequest: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + model_id: + type: string + type: object + schemas.ProviderResponse: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + responseId: + additionalProperties: + type: string + type: object + tokenCount: + $ref: '#/definitions/schemas.TokenUsage' + type: object + schemas.TokenUsage: + properties: + promptTokens: + type: number + responseTokens: + type: number + totalTokens: + type: number + type: object externalDocs: description: Documentation url: https://glide.einstack.ai/ @@ -464,14 +551,14 @@ paths: name: payload required: true schema: - $ref: '#/definitions/schemas.UnifiedChatRequest' + $ref: '#/definitions/schemas.ChatRequest' produces: - application/json responses: "200": description: OK schema: - $ref: '#/definitions/schemas.UnifiedChatResponse' + $ref: '#/definitions/schemas.ChatResponse' "400": description: Bad Request schema: diff --git a/go.mod b/go.mod index f2288d51..73a99abe 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,15 @@ module glide go 1.21.5 require ( + github.com/aws/aws-sdk-go-v2 v1.24.1 + github.com/aws/aws-sdk-go-v2/config v1.26.6 + github.com/aws/aws-sdk-go-v2/credentials v1.16.16 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.5.6 github.com/go-playground/validator/v10 v10.17.0 github.com/gofiber/contrib/fiberzap/v2 v2.1.2 github.com/gofiber/contrib/swagger v1.1.1 github.com/gofiber/fiber/v2 v2.52.0 + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 @@ -21,6 +26,17 @@ require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/andybalholm/brotli v1.0.5 // 
indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 // indirect + github.com/aws/smithy-go v1.19.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-openapi/analysis v0.21.4 // indirect @@ -35,8 +51,6 @@ require ( github.com/go-openapi/validate v0.22.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/google/go-cmp v0.5.5 // indirect - github.com/google/uuid v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.0 // indirect diff --git a/go.sum b/go.sum index 5d17b87c..a64851f6 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,36 @@ github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHG github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= +github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 h1:OCs21ST2LrepDfD3lwlQiOqIGp6JiEUqG84GzTDoyJs= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4/go.mod h1:usURWEKSNNAcAZuzRn/9ZYPT8aZQkR7xcCtunK/LkJo= +github.com/aws/aws-sdk-go-v2/config v1.26.6 h1:Z/7w9bUqlRI0FFQpetVuFYEsjzE3h7fpU6HuGmfPL/o= +github.com/aws/aws-sdk-go-v2/config v1.26.6/go.mod h1:uKU6cnDmYCvJ+pxO9S4cWDb2yWWIH5hra+32hVh1MI4= +github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.3 
h1:n3GDfwqF2tzEkXlv5cuy4iy7LpKDtqDMcNLfZDu9rls= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.3/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.5.6 h1:o6JbuIU5d53AghLHApGekjggjcV6yvIGHWpGxaVW6sw= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.5.6/go.mod h1:iyd1BBtwZS1lU/GW7AlhblRUbppI2IIjH9H6dRF18TM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -93,11 +123,11 @@ github.com/gofiber/fiber/v2 v2.52.0 h1:S+qXi7y+/Pgvqq4DrSmREGiFwtB7Bu6+QFLuIHYw/ github.com/gofiber/fiber/v2 v2.52.0/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= -github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= diff --git a/images/Makefile b/images/Makefile index 533f10c8..145f8b59 100644 --- a/images/Makefile +++ b/images/Makefile @@ -104,7 +104,7 @@ publish-ghcr-%: ## Push images to Github Registry } publish-dockerhub-%: ## Push 
images to Docker Hub - @echo "🚚Pushing the $* image to Github Registry.." + @echo "🚚Pushing the $* image to Docker Hub.." @docker tag $(REPOSITORY):$(VERSION)-$* $(REPOSITORY):$(VERSION)-$* @echo "- pushing $(REPOSITORY):$(VERSION)-$*" @docker push $(REPOSITORY):$(VERSION)-$* @@ -116,5 +116,8 @@ publish-dockerhub-%: ## Push images to Docker Hub docker tag $(REPOSITORY):$(VERSION)-$* $(REPOSITORY):latest; \ echo "- pushing $(REPOSITORY):latest"; \ docker push $(REPOSITORY):latest; \ + docker tag $(REPOSITORY):$(VERSION)-$* $(REPOSITORY):$(VERSION); \ + echo "- pushing $(REPOSITORY):$(VERSION)"; \ + docker push $(REPOSITORY):$(VERSION); \ fi; \ } diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 611bdedb..c97f542b 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -21,17 +21,17 @@ type Handler = func(c *fiber.Ctx) error // @Description Talk to different LLMs Chat API via unified endpoint // @tags Language // @Param router path string true "Router ID" -// @Param payload body schemas.UnifiedChatRequest true "Request Data" +// @Param payload body schemas.ChatRequest true "Request Data" // @Accept json // @Produce json -// @Success 200 {object} schemas.UnifiedChatResponse +// @Success 200 {object} schemas.ChatResponse // @Failure 400 {object} http.ErrorSchema // @Failure 404 {object} http.ErrorSchema // @Router /v1/language/{router}/chat [POST] func LangChatHandler(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { // Unmarshal request body - var req *schemas.UnifiedChatRequest + var req *schemas.ChatRequest err := c.BodyParser(&req) if err != nil { diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go index c06699c5..7e2a2cdc 100644 --- a/pkg/api/schemas/language.go +++ b/pkg/api/schemas/language.go @@ -1,7 +1,7 @@ package schemas -// UnifiedChatRequest defines Glide's Chat Request Schema unified across all language models -type UnifiedChatRequest struct { +// ChatRequest defines Glide's Chat Request Schema unified across all language models +type ChatRequest struct { Message ChatMessage `json:"message"` MessageHistory []ChatMessage `json:"messageHistory"` Override OverrideChatRequest `json:"override,omitempty"` @@ -12,8 +12,8 @@ type OverrideChatRequest struct { Message ChatMessage `json:"message"` } -func NewChatFromStr(message string) *UnifiedChatRequest { - return &UnifiedChatRequest{ +func NewChatFromStr(message string) *ChatRequest { + return &ChatRequest{ Message: ChatMessage{ "human", message, @@ -22,8 +22,8 @@ func NewChatFromStr(message string) *UnifiedChatRequest { } } -// UnifiedChatResponse defines Glide's Chat Response Schema unified across all language models -type UnifiedChatResponse struct { +// ChatResponse defines Glide's Chat Response Schema unified across all language models +type ChatResponse struct { ID string `json:"id,omitempty"` Created int `json:"created,omitempty"` Provider string `json:"provider,omitempty"` @@ -58,110 +58,3 @@ type ChatMessage struct { // with a maximum length of 64 characters. Name string `json:"name,omitempty"` } - -// OpenAI Chat Response (also used by Azure OpenAI and OctoML) -// TODO: Should this live here? 
-type OpenAIChatCompletion struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage"` -} - -type Choice struct { - Index int `json:"index"` - Message ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} - -type Usage struct { - PromptTokens float64 `json:"prompt_tokens"` - CompletionTokens float64 `json:"completion_tokens"` - TotalTokens float64 `json:"total_tokens"` -} - -// Cohere Chat Response -type CohereChatCompletion struct { - Text string `json:"text"` - GenerationID string `json:"generation_id"` - ResponseID string `json:"response_id"` - TokenCount CohereTokenCount `json:"token_count"` - Citations []Citation `json:"citations"` - Documents []Documents `json:"documents"` - SearchQueries []SearchQuery `json:"search_queries"` - SearchResults []SearchResults `json:"search_results"` - Meta Meta `json:"meta"` - ToolInputs map[string]interface{} `json:"tool_inputs"` -} - -type CohereTokenCount struct { - PromptTokens float64 `json:"prompt_tokens"` - ResponseTokens float64 `json:"response_tokens"` - TotalTokens float64 `json:"total_tokens"` - BilledTokens float64 `json:"billed_tokens"` -} - -type Meta struct { - APIVersion struct { - Version string `json:"version"` - } `json:"api_version"` - BilledUnits struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"billed_units"` -} - -type Citation struct { - Start int `json:"start"` - End int `json:"end"` - Text string `json:"text"` - DocumentID []string `json:"document_id"` -} - -type Documents struct { - ID string `json:"id"` - Data map[string]string `json:"data"` // TODO: This needs to be updated -} - -type SearchQuery struct { - Text string `json:"text"` - GenerationID string `json:"generation_id"` -} - -type SearchResults struct { - SearchQuery []SearchQueryObject `json:"search_query"` - Connectors []ConnectorsResponse `json:"connectors"` - DocumentID []string `json:"documentId"` -} - -type SearchQueryObject struct { - Text string `json:"text"` - GenerationID string `json:"generationId"` -} - -type ConnectorsResponse struct { - ID string `json:"id"` - UserAccessToken string `json:"user_access_token"` - ContOnFail string `json:"continue_on_failure"` - Options map[string]string `json:"options"` -} - -// Anthropic Chat Response -type AnthropicChatCompletion struct { - ID string `json:"id"` - Type string `json:"type"` - Model string `json:"model"` - Role string `json:"role"` - Content []Content `json:"content"` - StopReason string `json:"stop_reason"` - StopSequence string `json:"stop_sequence"` -} - -type Content struct { - Type string `json:"type"` - Text string `json:"text"` -} diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index b525bcb9..5a8d8ee3 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -49,7 +49,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -63,7 +63,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // 
Chat sends a chat request to the specified anthropic model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -79,7 +79,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -87,7 +87,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -154,7 +154,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var anthropicCompletion schemas.AnthropicChatCompletion + var anthropicCompletion ChatCompletion err = json.Unmarshal(bodyBytes, &anthropicCompletion) if err != nil { @@ -162,8 +162,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + // Map response to ChatResponse schema + response := schemas.ChatResponse{ ID: anthropicCompletion.ID, Created: int(time.Now().UTC().Unix()), // not provided by anthropic Provider: providerName, diff --git a/pkg/providers/anthropic/client_test.go b/pkg/providers/anthropic/client_test.go index c8927a37..9c301365 100644 --- a/pkg/providers/anthropic/client_test.go +++ b/pkg/providers/anthropic/client_test.go @@ -56,7 +56,7 @@ func TestAnthropicClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} @@ -86,7 +86,7 @@ func TestAnthropicClient_BadChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/anthropic/schamas.go b/pkg/providers/anthropic/schamas.go new file mode 100644 index 00000000..69b00248 --- /dev/null +++ b/pkg/providers/anthropic/schamas.go @@ -0,0 +1,17 @@ +package anthropic + +// Anthropic Chat Response +type ChatCompletion struct { + ID string `json:"id"` + Type string `json:"type"` + Model string `json:"model"` + Role string `json:"role"` + Content []Content `json:"content"` + StopReason string `json:"stop_reason"` + StopSequence string `json:"stop_sequence"` +} + +type Content struct { + Type string `json:"type"` + Text string 
`json:"text"` +} diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go index 6fda0305..f961587c 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/providers/azureopenai/chat.go @@ -9,9 +9,11 @@ import ( "net/http" "time" + "glide/pkg/api/schemas" + "glide/pkg/providers/openai" + "glide/pkg/providers/clients" - "glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -59,7 +61,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -73,7 +75,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified azure openai model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -89,7 +91,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -97,7 +99,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -164,7 +166,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var openAICompletion schemas.OpenAIChatCompletion + var openAICompletion openai.ChatCompletion err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { @@ -175,7 +177,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche openAICompletion.SystemFingerprint = "" // Azure OpenAI doesn't return this // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + response := schemas.ChatResponse{ ID: openAICompletion.ID, Created: openAICompletion.Created, Provider: providerName, diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index 62080029..8f5de037 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -55,7 +55,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "user", Content: "What's the biggest animal?", }} @@ -88,7 +88,7 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) { client, err := 
NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "user", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go new file mode 100644 index 00000000..14feb9bc --- /dev/null +++ b/pkg/providers/bedrock/chat.go @@ -0,0 +1,129 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "glide/pkg/api/schemas" + + "go.uber.org/zap" + + "github.com/google/uuid" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest is an Bedrock-specific request schema +type ChatRequest struct { + Messages string `json:"inputText"` + TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"` +} + +type TextGenerationConfig struct { + Temperature float64 `json:"temperature"` + TopP float64 `json:"topP"` + MaxTokenCount int `json:"maxTokenCount"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives +func NewChatRequestFromConfig(cfg *Config) *ChatRequest { + return &ChatRequest{ + TextGenerationConfig: TextGenerationConfig{ + MaxTokenCount: cfg.DefaultParams.MaxTokens, + StopSequences: cfg.DefaultParams.StopSequence, + Temperature: cfg.DefaultParams.Temperature, + TopP: cfg.DefaultParams.TopP, + }, + } +} + +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) string { + // message history not yet supported for AWS models + message := fmt.Sprintf("Role: %s, Content: %s", request.Message.Role, request.Message.Content) + + return message +} + +// Chat sends a chat request to the specified bedrock model. +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { + // Create a new chat request + chatRequest := c.createChatRequestSchema(request) + + chatResponse, err := c.doChatRequest(ctx, chatRequest) + if err != nil { + return nil, err + } + + if len(chatResponse.ModelResponse.Message.Content) == 0 { + return nil, ErrEmptyResponse + } + + return chatResponse, nil +} + +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := c.chatRequestTemplate // hoping to get a copy of the template + chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) + + return chatRequest +} + +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { + rawPayload, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("unable to marshal chat request payload: %w", err) + } + + result, err := c.bedrockClient.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(c.config.Model), + ContentType: aws.String("application/json"), + Body: rawPayload, + }) + if err != nil { + c.telemetry.Logger.Error("Error: Couldn't invoke model. 
Here's why: %v\n", zap.Error(err)) + return nil, err + } + + var bedrockCompletion ChatCompletion + + err = json.Unmarshal(result.Body, &bedrockCompletion) + if err != nil { + c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err)) + return nil, err + } + + response := schemas.ChatResponse{ + ID: uuid.NewString(), + Created: int(time.Now().Unix()), + Provider: "aws-bedrock", + Model: c.config.Model, + Cached: false, + ModelResponse: schemas.ProviderResponse{ + SystemID: map[string]string{ + "system_fingerprint": "none", + }, + Message: schemas.ChatMessage{ + Role: "assistant", + Content: bedrockCompletion.Results[0].OutputText, + Name: "", + }, + TokenUsage: schemas.TokenUsage{ + PromptTokens: float64(bedrockCompletion.Results[0].TokenCount), + ResponseTokens: -1, + TotalTokens: float64(bedrockCompletion.Results[0].TokenCount), + }, + }, + } + + return &response, nil +} diff --git a/pkg/providers/bedrock/client.go b/pkg/providers/bedrock/client.go new file mode 100644 index 00000000..130d06a7 --- /dev/null +++ b/pkg/providers/bedrock/client.go @@ -0,0 +1,73 @@ +package bedrock + +import ( + "context" + "errors" + "net/http" + "net/url" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +const ( + providerName = "bedrock" +) + +// ErrEmptyResponse is returned when the OpenAI API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// Client is a client for accessing OpenAI API +type Client struct { + baseURL string + bedrockClient *bedrockruntime.Client + chatURL string + chatRequestTemplate *ChatRequest + config *Config + httpClient *http.Client + telemetry *telemetry.Telemetry +} + +// NewClient creates a new OpenAI client for the OpenAI API. +func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { + chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint, providerConfig.Model, "/invoke") + if err != nil { + return nil, err + } + + cfg, _ := config.LoadDefaultConfig(context.TODO(), // Is this the right context? 
+ config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(providerConfig.AccessKey, providerConfig.SecretKey, "")), + config.WithRegion(providerConfig.AWSRegion), + ) + + bedrockClient := bedrockruntime.NewFromConfig(cfg) + + c := &Client{ + baseURL: providerConfig.BaseURL, + bedrockClient: bedrockClient, + chatURL: chatURL, + config: providerConfig, + chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + httpClient: &http.Client{ + Timeout: *clientConfig.Timeout, + // TODO: use values from the config + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, + }, + telemetry: tel, + } + + return c, nil +} + +func (c *Client) Provider() string { + return providerName +} diff --git a/pkg/providers/bedrock/client_test.go b/pkg/providers/bedrock/client_test.go new file mode 100644 index 00000000..bcbd0fa1 --- /dev/null +++ b/pkg/providers/bedrock/client_test.go @@ -0,0 +1,76 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +// TODO: Need to fix this test + +func TestBedrockClient_ChatRequest(t *testing.T) { + bedrockMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading bedrock chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + + _, err = w.Write(chatResponse) + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + BedrockServer := httptest.NewServer(bedrockMock) + defer BedrockServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = BedrockServer.URL + providerCfg.AccessKey = "abc" + providerCfg.SecretKey = "def" + providerCfg.AWSRegion = "us-west-2" + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.ChatRequest{Message: schemas.ChatMessage{ + Role: "user", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + + responseString := fmt.Sprintf("%+v", response) + // errString := fmt.Sprintf("%+v", err) + fmt.Println(responseString) + + println(response, err) +} diff --git a/pkg/providers/bedrock/config.go b/pkg/providers/bedrock/config.go new file mode 100644 index 00000000..a5608ece --- /dev/null +++ b/pkg/providers/bedrock/config.go @@ -0,0 +1,62 @@ +package bedrock + +import ( + "glide/pkg/config/fields" +) + +// Params defines OpenAI-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + Temperature float64 `yaml:"temperature" json:"temperature"` + TopP float64 `yaml:"top_p" json:"top_p"` + MaxTokens int `yaml:"max_tokens" json:"max_tokens"` + StopSequence []string `yaml:"stop_sequences" json:"stop"` +} + +func DefaultParams() Params { + return Params{ + Temperature: 0, + TopP: 1, + MaxTokens: 512, + StopSequence: []string{}, + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p 
= DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"` + ChatEndpoint string `yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"` + Model string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` + AccessKey string `yaml:"access_key" json:"-" validate:"required"` + SecretKey string `yaml:"secret_key" json:"-" validate:"required"` + AWSRegion string `yaml:"aws_region" json:"awsRegion" validate:"required"` + DefaultParams *Params `yaml:"defaultParams,omitempty" json:"defaultParams"` +} + +// DefaultConfig for OpenAI models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "", // This needs to come from config. https://bedrock-runtime.{{AWS_Region}}.amazonaws.com/ + ChatEndpoint: "/model", + Model: "amazon.titan-text-express-v1", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/bedrock/schemas.go b/pkg/providers/bedrock/schemas.go new file mode 100644 index 00000000..ac03de8e --- /dev/null +++ b/pkg/providers/bedrock/schemas.go @@ -0,0 +1,11 @@ +package bedrock + +// Bedrock Chat Response +type ChatCompletion struct { + InputTextTokenCount int `json:"inputTextTokenCount"` + Results []struct { + TokenCount int `json:"tokenCount"` + OutputText string `json:"outputText"` + CompletionReason string `json:"completionReason"` + } `json:"results"` +} diff --git a/pkg/providers/bedrock/testdata/chat.req.json b/pkg/providers/bedrock/testdata/chat.req.json new file mode 100644 index 00000000..c2e941d2 --- /dev/null +++ b/pkg/providers/bedrock/testdata/chat.req.json @@ -0,0 +1,12 @@ +{ + "model": "amazon.titan-text-express-v1", + "messages": [ + { + "role": "user", + "content": "What's the biggest animal?" + } + ], + "temperature": 0.8, + "top_p": 1, + "max_tokens": 100 +} diff --git a/pkg/providers/bedrock/testdata/chat.success.json b/pkg/providers/bedrock/testdata/chat.success.json new file mode 100644 index 00000000..eef056a7 --- /dev/null +++ b/pkg/providers/bedrock/testdata/chat.success.json @@ -0,0 +1 @@ +{"provider":"bedrock","model":"amazon.titan-text-express-v1","modelResponse":{"responseId":{"system_fingerprint":""},"message":{"role":"assistant","content":"\nThe largest animal in the world is the blue whale, which can reach lengths of over 100 feet and weigh as much as 200 tons."},"tokenCount":{"promptTokens":34,"responseTokens":34,"totalTokens":34}}} diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 28712887..165b67bd 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -65,7 +65,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified cohere model. 
-func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -81,7 +81,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Message = request.Message.Content @@ -103,7 +103,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -170,7 +170,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var cohereCompletion schemas.CohereChatCompletion + var cohereCompletion ChatCompletion err = json.Unmarshal(bodyBytes, &cohereCompletion) if err != nil { @@ -178,8 +178,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + // Map response to ChatResponse schema + response := schemas.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this Provider: providerName, @@ -206,7 +206,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return &response, nil } -func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.UnifiedChatResponse, error) { +func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse, error) { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) diff --git a/pkg/providers/cohere/client_test.go b/pkg/providers/cohere/client_test.go index 7828aa37..439e44d6 100644 --- a/pkg/providers/cohere/client_test.go +++ b/pkg/providers/cohere/client_test.go @@ -55,7 +55,7 @@ func TestCohereClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go new file mode 100644 index 00000000..c807aa56 --- /dev/null +++ b/pkg/providers/cohere/schemas.go @@ -0,0 +1,67 @@ +package cohere + +// Cohere Chat Response +type ChatCompletion struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ResponseID string `json:"response_id"` + TokenCount TokenCount `json:"token_count"` + Citations []Citation `json:"citations"` + Documents []Documents `json:"documents"` + SearchQueries []SearchQuery `json:"search_queries"` + SearchResults []SearchResults 
`json:"search_results"` + Meta Meta `json:"meta"` + ToolInputs map[string]interface{} `json:"tool_inputs"` +} + +type TokenCount struct { + PromptTokens float64 `json:"prompt_tokens"` + ResponseTokens float64 `json:"response_tokens"` + TotalTokens float64 `json:"total_tokens"` + BilledTokens float64 `json:"billed_tokens"` +} + +type Meta struct { + APIVersion struct { + Version string `json:"version"` + } `json:"api_version"` + BilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"billed_units"` +} + +type Citation struct { + Start int `json:"start"` + End int `json:"end"` + Text string `json:"text"` + DocumentID []string `json:"document_id"` +} + +type Documents struct { + ID string `json:"id"` + Data map[string]string `json:"data"` // TODO: This needs to be updated +} + +type SearchQuery struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` +} + +type SearchResults struct { + SearchQuery []SearchQueryObject `json:"search_query"` + Connectors []ConnectorsResponse `json:"connectors"` + DocumentID []string `json:"documentId"` +} + +type SearchQueryObject struct { + Text string `json:"text"` + GenerationID string `json:"generationId"` +} + +type ConnectorsResponse struct { + ID string `json:"id"` + UserAccessToken string `json:"user_access_token"` + ContOnFail string `json:"continue_on_failure"` + Options map[string]string `json:"options"` +} diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 07691e91..55d885f6 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -6,7 +6,9 @@ import ( "glide/pkg/routers/latency" + "glide/pkg/providers/bedrock" "glide/pkg/providers/clients" + "glide/pkg/providers/ollama" "glide/pkg/routers/health" @@ -33,6 +35,8 @@ type LangModelConfig struct { Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"` OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"` Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"` + Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"` + Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` } func DefaultLangModelConfig() *LangModelConfig { @@ -68,6 +72,8 @@ func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangModelProvide return octoml.NewClient(c.OctoML, c.Client, tel) case c.Anthropic != nil: return anthropic.NewClient(c.Anthropic, c.Client, tel) + case c.Bedrock != nil: + return bedrock.NewClient(c.Bedrock, c.Client, tel) default: return nil, ErrProviderNotFound } @@ -96,14 +102,22 @@ func (c *LangModelConfig) validateOneProvider() error { providersConfigured++ } + if c.Bedrock != nil { + providersConfigured++ + } + + if c.Ollama != nil { + providersConfigured++ + } + // check other providers here if providersConfigured == 0 { - return fmt.Errorf("exactly one provider must be cofigured for model \"%v\", none is configured", c.ID) + return fmt.Errorf("exactly one provider must be configured for model \"%v\", none is configured", c.ID) } if providersConfigured > 1 { return fmt.Errorf( - "exactly one provider must be cofigured for model \"%v\", %v are configured", + "exactly one provider must be configured for model \"%v\", %v are configured", c.ID, providersConfigured, ) diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go index 29ca6b7d..4860a0b9 100644 --- a/pkg/providers/octoml/chat.go +++ b/pkg/providers/octoml/chat.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + 
"glide/pkg/providers/openai" + "glide/pkg/providers/clients" "glide/pkg/api/schemas" @@ -47,7 +49,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -61,7 +63,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified octoml model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -77,7 +79,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -85,7 +87,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -152,7 +154,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var openAICompletion schemas.OpenAIChatCompletion // Octo uses the same response schema as OpenAI + var openAICompletion openai.ChatCompletion // Octo uses the same response schema as OpenAI err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { @@ -161,7 +163,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + response := schemas.ChatResponse{ ID: openAICompletion.ID, Created: openAICompletion.Created, Provider: providerName, diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index 1c5c7e63..c8a438c1 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -55,7 +55,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} @@ -88,7 +88,7 @@ func TestOctoMLClient_Chat_Error(t *testing.T) { require.NoError(t, err) // Create a chat request - request := schemas.UnifiedChatRequest{ + request := schemas.ChatRequest{ Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go new file mode 100644 index 00000000..f2247dd7 --- 
+++ b/pkg/providers/ollama/chat.go
@@ -0,0 +1,209 @@
+package ollama
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+
+	"glide/pkg/providers/clients"
+
+	"github.com/google/uuid"
+
+	"glide/pkg/api/schemas"
+	"go.uber.org/zap"
+)
+
+type ChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatRequest is an ollama-specific request schema
+type ChatRequest struct {
+	Model       string        `json:"model"`
+	Messages    []ChatMessage `json:"messages"`
+	Mirostat    int           `json:"mirostat,omitempty"`
+	MirostatEta float64       `json:"mirostat_eta,omitempty"`
+	MirostatTau float64       `json:"mirostat_tau,omitempty"`
+	NumCtx      int           `json:"num_ctx,omitempty"`
+	NumGqa      int           `json:"num_gqa,omitempty"`
+	NumGpu      int           `json:"num_gpu,omitempty"`
+	NumThread   int           `json:"num_thread,omitempty"`
+	RepeatLastN int           `json:"repeat_last_n,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+	Seed        int           `json:"seed,omitempty"`
+	StopWords   []string      `json:"stop,omitempty"`
+	Tfsz        float64       `json:"tfs_z,omitempty"`
+	NumPredict  int           `json:"num_predict,omitempty"`
+	TopK        int           `json:"top_k,omitempty"`
+	TopP        float64       `json:"top_p,omitempty"`
+	Stream      bool          `json:"stream"`
+}
+
+// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of the performance penalty it gives
+func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
+	return &ChatRequest{
+		Model:       cfg.Model,
+		Temperature: cfg.DefaultParams.Temperature,
+		Mirostat:    cfg.DefaultParams.Mirostat,
+		MirostatEta: cfg.DefaultParams.MirostatEta,
+		MirostatTau: cfg.DefaultParams.MirostatTau,
+		NumCtx:      cfg.DefaultParams.NumCtx,
+		NumGqa:      cfg.DefaultParams.NumGqa,
+		NumGpu:      cfg.DefaultParams.NumGpu,
+		NumThread:   cfg.DefaultParams.NumThread,
+		RepeatLastN: cfg.DefaultParams.RepeatLastN,
+		Seed:        cfg.DefaultParams.Seed,
+		StopWords:   cfg.DefaultParams.StopWords,
+		Tfsz:        cfg.DefaultParams.Tfsz,
+		NumPredict:  cfg.DefaultParams.NumPredict,
+		TopP:        cfg.DefaultParams.TopP,
+		TopK:        cfg.DefaultParams.TopK,
+		Stream:      cfg.DefaultParams.Stream,
+	}
+}
+
+func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
+	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+	// Add items from messageHistory first and the new chat message last
+	for _, message := range request.MessageHistory {
+		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
+	}
+
+	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+	return messages
+}
+
+// Chat sends a chat request to the specified ollama model.
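+// It maps the unified request to Ollama's /api/chat payload and converts the provider response back to schemas.ChatResponse.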
+func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+	// Create a new chat request
+	chatRequest := c.createChatRequestSchema(request)
+
+	chatResponse, err := c.doChatRequest(ctx, chatRequest)
+	if err != nil {
+		return nil, fmt.Errorf("chat request failed: %w", err)
+	}
+
+	if len(chatResponse.ModelResponse.Message.Content) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
+	return chatResponse, nil
+}
+
+func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+	// TODO: consider using objectpool to optimize memory allocation
+	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
+	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
+
+	return chatRequest
+}
+
+func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
+	// Build request payload
+	rawPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal ollama chat request payload: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+	if err != nil {
+		return nil, fmt.Errorf("unable to create ollama chat request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+
+	// TODO: this could leak information from messages which may not be a desired thing to have
+	c.telemetry.Logger.Debug(
+		"ollama chat request",
+		zap.String("chat_url", c.chatURL),
+		zap.Any("payload", payload),
+	)
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to send ollama chat request: %w", err)
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, err := io.ReadAll(resp.Body)
+		if err != nil {
+			c.telemetry.Logger.Error("failed to read ollama chat response", zap.Error(err))
+		}
+
+		c.telemetry.Logger.Error(
+			"ollama chat request failed",
+			zap.Int("status_code", resp.StatusCode),
+			zap.String("response", string(bodyBytes)),
+			zap.Any("headers", resp.Header),
+		)
+
+		if resp.StatusCode == http.StatusTooManyRequests {
+			// Read the value of the "Retry-After" header to get the cooldown delay
+			retryAfter := resp.Header.Get("Retry-After")
+
+			// Parse the value to get the duration
+			cooldownDelay, err := time.ParseDuration(retryAfter)
+			if err != nil {
+				return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+			}
+
+			return nil, clients.NewRateLimitError(&cooldownDelay)
+		}
+
+		// Server & client errors result in the same error to keep gateway resilient
+		return nil, clients.ErrProviderUnavailable
+	}
+
+	// Read the response body into a byte slice
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		c.telemetry.Logger.Error("failed to read ollama chat response", zap.Error(err))
+		return nil, err
+	}
+
+	// Parse the response JSON
+	var ollamaCompletion ChatCompletion
+
+	err = json.Unmarshal(bodyBytes, &ollamaCompletion)
+	if err != nil {
+		c.telemetry.Logger.Error("failed to parse ollama chat response", zap.Error(err))
+		return nil, err
+	}
+
+	// Map response to ChatResponse schema
+	response := schemas.ChatResponse{
+		ID:       uuid.NewString(),
+		Created:  int(time.Now().Unix()),
+		Provider: providerName,
+		Model:    ollamaCompletion.Model,
+		Cached:   false,
+		ModelResponse: schemas.ProviderResponse{
+			SystemID: map[string]string{
+				"system_fingerprint": "",
+			},
+			Message: schemas.ChatMessage{
+				Role:    ollamaCompletion.Message.Role,
+				Content: ollamaCompletion.Message.Content,
+				Name:    "",
+			},
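+			// Ollama reports prompt tokens in prompt_eval_count and completion tokens in eval_count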
+			TokenUsage: schemas.TokenUsage{
+				PromptTokens:   float64(ollamaCompletion.PromptEvalCount),
+				ResponseTokens: float64(ollamaCompletion.EvalCount),
+				TotalTokens:    float64(ollamaCompletion.PromptEvalCount + ollamaCompletion.EvalCount),
+			},
+		},
+	}
+
+	return &response, nil
+}
diff --git a/pkg/providers/ollama/client.go b/pkg/providers/ollama/client.go
new file mode 100644
index 00000000..00043025
--- /dev/null
+++ b/pkg/providers/ollama/client.go
@@ -0,0 +1,59 @@
+package ollama
+
+import (
+	"errors"
+	"net/http"
+	"net/url"
+
+	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+)
+
+const (
+	providerName = "ollama"
+)
+
+// ErrEmptyResponse is returned when the Ollama API returns an empty response.
+var (
+	ErrEmptyResponse = errors.New("empty response")
+)
+
+// Client is a client for accessing the Ollama API
+type Client struct {
+	baseURL             string
+	chatURL             string
+	chatRequestTemplate *ChatRequest
+	config              *Config
+	httpClient          *http.Client
+	telemetry           *telemetry.Telemetry
+}
+
+// NewClient creates a new client for the Ollama API.
+func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) {
+	chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &Client{
+		baseURL:             providerConfig.BaseURL,
+		chatURL:             chatURL,
+		config:              providerConfig,
+		chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+		httpClient: &http.Client{
+			Timeout: *clientConfig.Timeout,
+			// TODO: use values from the config
+			Transport: &http.Transport{
+				MaxIdleConns:        100,
+				MaxIdleConnsPerHost: 2,
+			},
+		},
+		telemetry: tel,
+	}
+
+	return c, nil
+}
+
+func (c *Client) Provider() string {
+	return providerName
+}
diff --git a/pkg/providers/ollama/client_test.go b/pkg/providers/ollama/client_test.go
new file mode 100644
index 00000000..3a85d397
--- /dev/null
+++ b/pkg/providers/ollama/client_test.go
@@ -0,0 +1,128 @@
+package ollama
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"glide/pkg/providers/clients"
+
+	"glide/pkg/api/schemas"
+
+	"glide/pkg/telemetry"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestOllamaClient_ChatRequest(t *testing.T) {
+	OllamaAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		rawPayload, _ := io.ReadAll(r.Body)
+
+		var data interface{}
+		// Parse the JSON body
+		err := json.Unmarshal(rawPayload, &data)
+		if err != nil {
+			t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+		}
+
+		chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json"))
+		if err != nil {
+			t.Errorf("error reading ollama chat mock response: %v", err)
+		}
+
+		w.Header().Set("Content-Type", "application/json")
+
+		_, err = w.Write(chatResponse)
+		if err != nil {
+			t.Errorf("error on sending chat response: %v", err)
+		}
+	})
+
+	OllamaServer := httptest.NewServer(OllamaAIMock)
+	defer OllamaServer.Close()
+
+	ctx := context.Background()
+	providerCfg := DefaultConfig()
+
+	clientCfg := clients.DefaultClientConfig()
+
+	providerCfg.BaseURL = OllamaServer.URL // point the client at the mock server instead of a live Ollama instance
+	providerCfg.Model = "llama2"
+
+	client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+	require.NoError(t, err)
+
+	request := schemas.ChatRequest{Message: schemas.ChatMessage{
+		Role:    "user",
+		Content: "What's the biggest animal?",
+	}}
+
+	response, err := client.Chat(ctx, &request)
+
+	require.NoError(t, err)
+	require.Equal(t, "assistant", response.ModelResponse.Message.Role)
+}
+
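+// TestOllamaClient_ChatRequest_Non200Response ensures client and server errors are surfaced as ErrProviderUnavailable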
request failed") +} + +func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) { + // Create a mock HTTP server that returns a non-OK status code + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + + defer mockServer.Close() + + // Create a new client with the mock server URL + client := &Client{ + httpClient: http.DefaultClient, + chatURL: mockServer.URL, + config: DefaultConfig(), + telemetry: telemetry.NewTelemetryMock(), + } + + // Create a chat request payload + payload := &ChatRequest{ + Messages: []ChatMessage{{Role: "human", Content: "Hello"}}, + } + + // Call the chatRequest function + _, err := client.doChatRequest(context.Background(), payload) + + require.Error(t, err) + require.Contains(t, err.Error(), "provider is not available") +} + +func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { + // Create a mock HTTP server that returns an OK status code and a sample response + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"response": "OK"}`)) + })) + + defer mockServer.Close() + + // Create a new client with the mock server URL + client := &Client{ + httpClient: http.DefaultClient, + chatURL: mockServer.URL, + config: DefaultConfig(), + telemetry: telemetry.NewTelemetryMock(), + } + + // Create a chat request payload + payload := &ChatRequest{ + Messages: []ChatMessage{{Role: "human", Content: "Hello"}}, + } + + // Call the chatRequest function + response, err := client.doChatRequest(context.Background(), payload) + + require.NoError(t, err) + require.NotNil(t, response) + require.Equal(t, "", response.ModelResponse.Message.Role) +} diff --git a/pkg/providers/ollama/config.go b/pkg/providers/ollama/config.go new file mode 100644 index 00000000..7363db58 --- /dev/null +++ b/pkg/providers/ollama/config.go @@ -0,0 +1,67 @@ +package ollama + +// Params defines Ollmama-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + Temperature float64 `yaml:"temperature,omitempty" json:"temperature"` + TopP float64 `yaml:"top_p,omitempty" json:"top_p"` + Microstat int `yaml:"microstat,omitempty" json:"microstat"` + MicrostatEta float64 `yaml:"microstat_eta,omitempty" json:"microstat_eta"` + MicrostatTau float64 `yaml:"microstat_tau,omitempty" json:"microstat_tau"` + NumCtx int `yaml:"num_ctx,omitempty" json:"num_ctx"` + NumGqa int `yaml:"num_gqa,omitempty" json:"num_gqa"` + NumGpu int `yaml:"num_gpu,omitempty" json:"num_gpu"` + NumThread int `yaml:"num_thread,omitempty" json:"num_thread"` + RepeatLastN int `yaml:"repeat_last_n,omitempty" json:"repeat_last_n"` + Seed int `yaml:"seed,omitempty" json:"seed"` + StopWords []string `yaml:"stop,omitempty" json:"stop"` + Tfsz float64 `yaml:"tfs_z,omitempty" json:"tfs_z"` + NumPredict int `yaml:"num_predict,omitempty" json:"num_predict"` + TopK int `yaml:"top_k,omitempty" json:"top_k"` + Stream bool `yaml:"stream,omitempty" json:"stream"` +} + +func DefaultParams() Params { + return Params{ + Temperature: 0.8, + NumCtx: 2048, + TopP: 0.9, + TopK: 40, + Stream: false, + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p = DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"` + ChatEndpoint string 
`yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"` + Model string `yaml:"model" json:"model" validate:"required"` + DefaultParams *Params `yaml:"defaultParams,omitempty" json:"defaultParams"` +} + +// DefaultConfig for OpenAI models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "http://localhost:11434", + ChatEndpoint: "/api/chat", + Model: "", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/ollama/schemas.go b/pkg/providers/ollama/schemas.go new file mode 100644 index 00000000..39311c0f --- /dev/null +++ b/pkg/providers/ollama/schemas.go @@ -0,0 +1,17 @@ +package ollama + +type ChatCompletion struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + Done bool `json:"done"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration int64 `json:"eval_duration"` +} diff --git a/pkg/providers/ollama/testdata/chat.req.json b/pkg/providers/ollama/testdata/chat.req.json new file mode 100644 index 00000000..8af17718 --- /dev/null +++ b/pkg/providers/ollama/testdata/chat.req.json @@ -0,0 +1,11 @@ +{ + "model": "llama2", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" + } + ], + "temperature": 0.8, + "stream": false +} diff --git a/pkg/providers/ollama/testdata/chat.success.json b/pkg/providers/ollama/testdata/chat.success.json new file mode 100644 index 00000000..0a674356 --- /dev/null +++ b/pkg/providers/ollama/testdata/chat.success.json @@ -0,0 +1,15 @@ +{ + "model": "registry.ollama.ai/library/llama2:latest", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": { + "role": "assistant", + "content": "Hello! How are you today?" + }, + "done": true, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000 +} \ No newline at end of file diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index c296c080..bbcc4ff4 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -61,7 +61,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -75,7 +75,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified OpenAI model. 
-func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -91,7 +91,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -99,7 +99,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -166,7 +166,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var openAICompletion schemas.OpenAIChatCompletion + var openAICompletion ChatCompletion err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { @@ -174,8 +174,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + // Map response to ChatResponse schema + response := schemas.ChatResponse{ ID: openAICompletion.ID, Created: openAICompletion.Created, Provider: providerName, diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index 7a825cd5..c56f227b 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -9,10 +9,6 @@ import ( "glide/pkg/telemetry" ) -// TODO: Explore resource pooling -// TODO: Optimize Type use -// TODO: Explore Hertz TLS & resource pooling - const ( providerName = "openai" ) diff --git a/pkg/providers/openai/client_test.go b/pkg/providers/openai/client_test.go index f8ca7e6a..6bd8298d 100644 --- a/pkg/providers/openai/client_test.go +++ b/pkg/providers/openai/client_test.go @@ -56,8 +56,8 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ - Role: "human", + request := schemas.ChatRequest{Message: schemas.ChatMessage{ + Role: "user", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/openai/schemas.go b/pkg/providers/openai/schemas.go new file mode 100644 index 00000000..cf41aebf --- /dev/null +++ b/pkg/providers/openai/schemas.go @@ -0,0 +1,26 @@ +package openai + +// OpenAI Chat Response (also used by Azure OpenAI and OctoML) + +type ChatCompletion struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + 
FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens float64 `json:"prompt_tokens"` + CompletionTokens float64 `json:"completion_tokens"` + TotalTokens float64 `json:"total_tokens"` +} diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go index 4a3774b2..399d6ee7 100644 --- a/pkg/providers/provider.go +++ b/pkg/providers/provider.go @@ -15,7 +15,7 @@ import ( // LangModelProvider defines an interface a provider should fulfill to be able to serve language chat requests type LangModelProvider interface { Provider() string - Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) + Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) } type Model interface { @@ -78,7 +78,7 @@ func (m *LangModel) Weight() int { return m.weight } -func (m *LangModel) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (m *LangModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { startedAt := time.Now() resp, err := m.client.Chat(ctx, request) diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go index f408380c..890421a0 100644 --- a/pkg/providers/testing.go +++ b/pkg/providers/testing.go @@ -14,8 +14,8 @@ type ResponseMock struct { Err *error } -func (m *ResponseMock) Resp() *schemas.UnifiedChatResponse { - return &schemas.UnifiedChatResponse{ +func (m *ResponseMock) Resp() *schemas.ChatResponse { + return &schemas.ChatResponse{ ID: "rsp0001", ModelResponse: schemas.ProviderResponse{ SystemID: map[string]string{ @@ -40,7 +40,7 @@ func NewProviderMock(responses []ResponseMock) *ProviderMock { } } -func (c *ProviderMock) Chat(_ context.Context, _ *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) { response := c.responses[c.idx] c.idx++ diff --git a/pkg/routers/router.go b/pkg/routers/router.go index c2149c7a..13d89fa3 100644 --- a/pkg/routers/router.go +++ b/pkg/routers/router.go @@ -55,7 +55,7 @@ func (r *LangRouter) ID() string { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { if len(r.models) == 0 { return nil, ErrNoModels }