Skip to content

Commit

Permalink
Adding LM Studio Support (stitionai#389)
Browse files Browse the repository at this point in the history
* adding LM Studio

* include LM Studio in simple config

* Delete ui/package-lock.json

That file shouldn't have been committed
  • Loading branch information
ayoubachak authored Sep 19, 2024
1 parent 04af76c commit 3b98ed3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
3 changes: 3 additions & 0 deletions sample.config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ NETLIFY = "<YOUR_NETLIFY_API_KEY>"
BING = "https://api.bing.microsoft.com/v7.0/search"
GOOGLE = "https://www.googleapis.com/customsearch/v1"
OLLAMA = "http://127.0.0.1:11434"

LM_STUDIO = "http://localhost:1234/v1"
OPENAI = "https://api.openai.com/v1"


[LOGGING]
LOG_REST_API = "true"
LOG_PROMPTS = "false"
Expand Down
7 changes: 7 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def get_google_search_api_endpoint(self):

def get_ollama_api_endpoint(self):
    """Return the Ollama API endpoint URL from the loaded config."""
    endpoints = self.config["API_ENDPOINTS"]
    return endpoints["OLLAMA"]

def get_lmstudio_api_endpoint(self):
    """Return the LM Studio API endpoint URL from the loaded config."""
    endpoints = self.config["API_ENDPOINTS"]
    return endpoints["LM_STUDIO"]

def get_claude_api_key(self):
return self.config["API_KEYS"]["CLAUDE"]
Expand Down Expand Up @@ -131,6 +134,10 @@ def set_google_search_api_endpoint(self, endpoint):
def set_ollama_api_endpoint(self, endpoint):
    """Store *endpoint* as the Ollama API endpoint and persist the config."""
    endpoints = self.config["API_ENDPOINTS"]
    endpoints["OLLAMA"] = endpoint
    self.save_config()

def set_lmstudio_api_endpoint(self, endpoint):
    """Store *endpoint* as the LM Studio API endpoint and persist the config."""
    endpoints = self.config["API_ENDPOINTS"]
    endpoints["LM_STUDIO"] = endpoint
    self.save_config()

def set_claude_api_key(self, key):
self.config["API_KEYS"]["CLAUDE"] = key
Expand Down
10 changes: 8 additions & 2 deletions src/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .gemini_client import Gemini
from .mistral_client import MistralAi
from .groq_client import Groq
from .lm_studio_client import LMStudio

from src.state import AgentState

Expand Down Expand Up @@ -60,7 +61,11 @@ def __init__(self, model_id: str = None):
("Mixtral", "mixtral-8x7b-32768"),
("GEMMA 7B", "gemma-7b-it"),
],
"OLLAMA": []
"OLLAMA": [],
"LM_STUDIO": [
("LM Studio", "local-model"),
],

}
if ollama.client:
self.models["OLLAMA"] = [(model["name"], model["name"]) for model in ollama.models]
Expand Down Expand Up @@ -99,7 +104,8 @@ def inference(self, prompt: str, project_name: str) -> str:
"OPENAI": OpenAi(),
"GOOGLE": Gemini(),
"MISTRAL": MistralAi(),
"GROQ": Groq()
"GROQ": Groq(),
"LM_STUDIO": LMStudio()
}

try:
Expand Down
30 changes: 30 additions & 0 deletions src/llm/lm_studio_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from src.logger import Logger
from src.config import Config
from openai import OpenAI


log = Logger()

class LMStudio:
    """Client for a local LM Studio server via its OpenAI-compatible API."""

    def __init__(self):
        """Read the LM Studio endpoint from config and build an OpenAI client.

        If the endpoint is missing or client construction fails, the
        instance is left disabled (``client`` is ``None``) and warnings are
        logged instead of raising, preserving the original best-effort
        behaviour.
        """
        try:
            self.api_endpoint = Config().get_lmstudio_api_endpoint()
            # LM Studio ignores the API key, but the OpenAI client requires one.
            self.client = OpenAI(base_url=self.api_endpoint, api_key="not-needed")
            log.info("LM Studio available")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed during startup.
            self.api_endpoint = None
            self.client = None
            log.warning("LM Studio not available")
            log.warning("Make sure to set the LM Studio API endpoint in the config")

    def inference(self, model_id: str, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Raises:
            RuntimeError: if the client was never initialised (previously
                this surfaced as an opaque ``AttributeError`` on ``None``).
        """
        if self.client is None:
            raise RuntimeError(
                "LM Studio client is not initialised; "
                "set the LM Studio API endpoint in the config"
            )
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt.strip(),
                }
            ],
            # LM Studio serves whatever model is loaded locally; the id is
            # passed through but typically ignored by the server.
            model=model_id,
        )
        return chat_completion.choices[0].message.content

0 comments on commit 3b98ed3

Please sign in to comment.