From e9203f0e160b18815c053e9269df1373854c0e35 Mon Sep 17 00:00:00 2001
From: the-gigi
Date: Wed, 3 Apr 2024 10:55:10 -0700
Subject: [PATCH] WIP on azure assistant API support

---
 rundemo.sh                                    |   5 +-
 .../demo/AssistantAzureServiceDemo.java       | 215 ++++++++++++++++++
 .../openai/demo/ChatAzureServiceDemo.java     |   6 +-
 .../openai/SimpleOpenAIAzure.java             |  83 ++++++-
 4 files changed, 302 insertions(+), 7 deletions(-)
 create mode 100644 src/demo/java/io/github/sashirestela/openai/demo/AssistantAzureServiceDemo.java

diff --git a/rundemo.sh b/rundemo.sh
index 1dc5653f..97c6c267 100755
--- a/rundemo.sh
+++ b/rundemo.sh
@@ -11,7 +11,7 @@ log_file_value="demo.log"
 log_options=""
 
 if [ ! -z "$1" ]; then
-  service="${1^}"
+  service=$(echo "$1" | awk '{ print toupper(substr($0, 1, 1)) substr($0, 2) }')
 fi
 
 if [ ! -z "$2" ]; then
@@ -19,7 +19,8 @@ if [ ! -z "$2" ]; then
 fi
 
 if [ "$log_level_value" != "off" ]; then
-  log_options="${log_level_param}=${log_level_value} ${log_file_param}=${log_file_value}"
+  log_options="${log_level_param}=${log_level_value}"
+  #log_options="${log_level_param}=${log_level_value} ${log_file_param}=${log_file_value}"
 fi
 
 main_class="io.github.sashirestela.openai.demo.${service}ServiceDemo"
diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AssistantAzureServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AssistantAzureServiceDemo.java
new file mode 100644
index 00000000..1562e776
--- /dev/null
+++ b/src/demo/java/io/github/sashirestela/openai/demo/AssistantAzureServiceDemo.java
@@ -0,0 +1,215 @@
+package io.github.sashirestela.openai.demo;
+
+import io.github.sashirestela.openai.SimpleOpenAIAzure;
+import io.github.sashirestela.openai.domain.assistant.AssistantRequest;
+import io.github.sashirestela.openai.domain.assistant.AssistantTool;
+import io.github.sashirestela.openai.domain.assistant.Events;
+import io.github.sashirestela.openai.domain.assistant.ImageFileContent;
+import io.github.sashirestela.openai.domain.assistant.TextContent;
+import io.github.sashirestela.openai.domain.assistant.ThreadMessage;
+import io.github.sashirestela.openai.domain.assistant.ThreadMessageDelta;
+import io.github.sashirestela.openai.domain.assistant.ThreadMessageRequest;
+import io.github.sashirestela.openai.domain.assistant.ThreadRequest;
+import io.github.sashirestela.openai.domain.assistant.ThreadRun;
+import io.github.sashirestela.openai.domain.assistant.ThreadRunRequest;
+import io.github.sashirestela.openai.domain.file.FileRequest;
+import io.github.sashirestela.openai.domain.file.PurposeType;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Demo of the Assistants API (assistants, files, threads and runs) using {@link SimpleOpenAIAzure}.
+ */
+public class AssistantAzureServiceDemo extends AbstractDemo {
+
+    /** Maximum number of one-second status polls before a run is treated as stuck. */
+    private static final int MAX_POLL_ATTEMPTS = 120;
+
+    String assistantId;
+    String fileId;
+    String threadId;
+    String runId;
+
+    /**
+     * Builds the Azure client only. The assistant is created by {@link #demoCreateAssistant()}, so a
+     * run creates exactly one assistant and {@link #demoDeleteAssistant()} removes it again.
+     *
+     * @param baseUrl    Base url handed to the {@link SimpleOpenAIAzure} builder.
+     * @param apiKey     Api key handed to the {@link SimpleOpenAIAzure} builder.
+     * @param apiVersion Api version handed to the {@link SimpleOpenAIAzure} builder.
+     */
+    public AssistantAzureServiceDemo(String baseUrl, String apiKey, String apiVersion) {
+        super(SimpleOpenAIAzure.builder()
+                .apiKey(apiKey)
+                .baseUrl(baseUrl)
+                .apiVersion(apiVersion)
+                .build());
+    }
+
+    /** Creates the assistant and stores its id for the demos that follow. */
+    public void demoCreateAssistant() {
+        var assistantRequest = AssistantRequest.builder()
+                .model("N/A")
+                .build();
+
+        var assistant = openAI.assistants().create(assistantRequest).join();
+        System.out.println(assistant);
+        assistantId = assistant.getId();
+    }
+
+    /** Retrieves the assistant, then sets its name, instructions and code interpreter tool. */
+    public void demoRetrieveAndModifyAssistant() {
+        var assistant = openAI.assistants().getOne(assistantId).join();
+        var assistantRequest = assistant.mutate()
+                .name("Math Expert")
+                .instructions(
+                        "You are a personal math expert. When asked a question, write and run Python code to answer the question.")
+                .tool(AssistantTool.CODE_INTERPRETER)
+                .build();
+
+        assistant = openAI.assistants().modify(assistant.getId(), assistantRequest).join();
+        System.out.println(assistant);
+    }
+
+    /** Prints every listed assistant, numbered from 1. */
+    public void demoListAssistants() {
+        AtomicInteger count = new AtomicInteger();
+        openAI.assistants()
+                .getList()
+                .join()
+                .forEach(r -> System.out.println("\n#" + count.incrementAndGet() + "\n" + r));
+    }
+
+    /** Uploads the sample text file and attaches it to the assistant. */
+    public void demoUploadAssistantFile() {
+        var fileRequest = FileRequest.builder()
+                .file(Paths.get("src/demo/resources/code_interpreter_file.txt"))
+                .purpose(PurposeType.ASSISTANTS)
+                .build();
+        var file = openAI.files().create(fileRequest).join();
+        var assistantFile = openAI.assistants().createFile(assistantId, file.getId()).join();
+        System.out.println(assistantFile);
+        fileId = file.getId();
+    }
+
+    /** Creates a thread seeded with one user message and stores its id. */
+    public void demoCreateThread() {
+        var threadRequest = ThreadRequest.builder()
+                .message(ThreadMessageRequest.builder()
+                        .role("user")
+                        .content(
+                                "Inspect the content of the attached text file. After that plot graph of the formula requested in it.")
+                        .build())
+                .build();
+
+        var thread = openAI.threads().create(threadRequest).join();
+        System.out.println(thread);
+        threadId = thread.getId();
+    }
+
+    /**
+     * Starts a run and polls it once per second until its status is {@code COMPLETED}.
+     *
+     * @throws IllegalStateException if the run is still not completed after {@link #MAX_POLL_ATTEMPTS} polls.
+     */
+    public void demoRunThreadAndWaitUntilComplete() {
+        var run = openAI.threads().createRun(threadId, assistantId).join();
+        runId = run.getId();
+
+        var attempts = 0;
+        while (!run.getStatus().equals(ThreadRun.Status.COMPLETED)) {
+            // Without a bound, a run that never reaches COMPLETED would keep this loop spinning forever.
+            if (++attempts > MAX_POLL_ATTEMPTS) {
+                throw new IllegalStateException(
+                        "Run " + runId + " did not complete; last status: " + run.getStatus());
+            }
+            sleep(1);
+            run = openAI.threads().getRun(run.getThreadId(), run.getId()).join();
+        }
+        System.out.println(run);
+
+        var messages = openAI.threads().getMessageList(threadId).join();
+        System.out.println(messages);
+    }
+
+    /** Starts a streamed run and prints the text of each message delta event. */
+    public void demoRunThreadAndStream() {
+        var request = ThreadRunRequest.builder().assistantId(assistantId).build();
+        var response = openAI.threads().createRunStream(threadId, request).join();
+        response.filter(e -> e.getName().equals(Events.THREAD_MESSAGE_DELTA))
+                .map(e -> ((TextContent) ((ThreadMessageDelta) e.getData()).getDelta().getContent().get(0)).getValue())
+                .forEach(System.out::print);
+        System.out.println();
+    }
+
+    /** Prints the thread messages and downloads the first image found in the first message, if any. */
+    public void demoGetAssistantMessages() {
+        List<ThreadMessage> messages = openAI.threads().getMessageList(threadId).join();
+        ThreadMessage assistant = messages.get(0);
+        ImageFileContent assistantImageContent = assistant.getContent()
+                .stream()
+                .filter(ImageFileContent.class::isInstance)
+                .map(ImageFileContent.class::cast)
+                .findFirst()
+                .orElse(null);
+
+        System.out.println("All messages:");
+        System.out.println("=============");
+        System.out.println(messages);
+
+        if (assistantImageContent != null) {
+            System.out.println("\nAssistant answer contains an image. Downloading it now...");
+            try (var in = openAI.files()
+                    .getContentInputStream(assistantImageContent.getImageFile().getFileId())
+                    .join()) {
+                Path tempFile = Files.createTempFile("code_interpreter", ".png");
+                Files.write(tempFile, in.readAllBytes());
+                System.out.println("Image file downloaded to: " + tempFile.toUri());
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    /** Deletes the assistant created by {@link #demoCreateAssistant()}. */
+    public void demoDeleteAssistant() {
+        openAI.assistants().delete(assistantId).join();
+    }
+
+    /** Sleeps for the given number of seconds, restoring the interrupt flag if the wait is interrupted. */
+    private static void sleep(int seconds) {
+        try {
+            java.lang.Thread.sleep(1000L * seconds);
+        } catch (InterruptedException e) {
+            java.lang.Thread.currentThread().interrupt();
+            throw new RuntimeException(e);
+        }
+    }
+
+    /** Reads the Azure settings from the environment, registers every demo and runs them in order. */
+    public static void main(String[] args) {
+        var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
+        var apiKey = System.getenv("AZURE_OPENAI_API_KEY");
+        var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION");
+
+        var demo = new AssistantAzureServiceDemo(baseUrl, apiKey, apiVersion);
+
+        demo.addTitleAction("Demo Call Assistant Create", demo::demoCreateAssistant);
+        demo.addTitleAction("Demo Call Assistant Retrieve and Modify", demo::demoRetrieveAndModifyAssistant);
+        demo.addTitleAction("Demo Call Assistant List", demo::demoListAssistants);
+        demo.addTitleAction("Demo Call Assistant File Upload", demo::demoUploadAssistantFile);
+        demo.addTitleAction("Demo Call Assistant Thread Create", demo::demoCreateThread);
+        demo.addTitleAction("Demo Call Assistant Thread Run", demo::demoRunThreadAndWaitUntilComplete);
+        demo.addTitleAction("Demo Call Assistant Thread Run Stream", demo::demoRunThreadAndStream);
+        demo.addTitleAction("Demo Call Assistant Messages Get", demo::demoGetAssistantMessages);
+        demo.addTitleAction("Demo Call Assistant Delete", demo::demoDeleteAssistant);
+
+        demo.run();
+    }
+
+}
diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java
index 8f86b5b3..7425ac91 100644
--- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java
+++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java
@@ -160,14 +160,14 @@ public static void main(String[] args) {
         var demo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion);
 
         demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking);
-        if (baseUrl.contains("gpt-35-turbo")) {
+        if (baseUrl.contains("gpt-4-0125-Preview")) {
             demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions);
-        } else if (baseUrl.contains("gpt-4")) {
+        } else if (baseUrl.contains("gpt-35-turbo")) {
             demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming);
+        } else if (baseUrl.contains("gpt-4-vision")) {
             demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage);
             demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage);
         }
-
         demo.run();
     }
 
diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java
index 7e186979..97730eae 100644
--- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java
+++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java
@@ -3,6 +3,8 @@
 import io.github.sashirestela.cleverclient.http.HttpRequestData;
 import io.github.sashirestela.cleverclient.support.ContentType;
 import io.github.sashirestela.openai.support.Constant;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import lombok.Builder;
 import lombok.NonNull;
 
@@ -17,6 +19,15 @@
  */
 public class SimpleOpenAIAzure extends BaseSimpleOpenAI {
 
+    // Captures the deployment name from a URL of the form ".../deployments/<name>/...".
+    private static final Pattern DEPLOYMENT_PATTERN = Pattern.compile("/deployments/([^/]+)/");
+    // Matches the model property alone, without the commas around it.
+    private static final Pattern MODEL_PROPERTY_PATTERN = Pattern.compile("\"model\":\"[^\"]*\"");
+
+    private OpenAI.Files fileService;
+    private OpenAI.Assistants assistantService;
+    private OpenAI.Threads threadService;
+
     /**
      * Constructor used to generate a builder.
      *
@@ -34,6 +45,17 @@ public SimpleOpenAIAzure(@NonNull String apiKey, @NonNull String baseUrl, @NonNu
         super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, apiVersion, httpClient));
     }
 
+    /**
+     * Extracts the deployment name from a url containing a {@code /deployments/<name>/} segment.
+     *
+     * @param url Request url to inspect.
+     * @return The deployment name, or {@code null} when the url has no deployment segment.
+     */
+    private static String extractDeployment(String url) {
+        var matcher = DEPLOYMENT_PATTERN.matcher(url);
+        return matcher.find() ? matcher.group(1) : null;
+    }
+
     public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, String apiVersion,
             HttpClient httpClient) {
 
@@ -50,21 +72,42 @@ public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, St
             var contentType = request.getContentType();
             var body = request.getBody();
 
+            var deployment = extractDeployment(url);
+            var chatCompletions = url.contains("/chat/completions");
+
             url += (url.contains("?") ? "&" : "?") + Constant.AZURE_API_VERSION + "=" + apiVersion;
             url = url.replaceFirst(VERSION_REGEX, "");
+
+            // Strip the deployment from the URL unless this is a /chat/completions call.
+            if (!chatCompletions) {
+                url = url.replaceFirst("/deployments/[^/]+/", "/");
+            }
+
             request.setUrl(url);
 
             if (contentType != null) {
                 if (contentType.equals(ContentType.APPLICATION_JSON)) {
                     var bodyJson = (String) request.getBody();
-                    bodyJson = bodyJson.replaceFirst(MODEL_REGEX, "");
+                    if (chatCompletions) {
+                        // Chat completions keep the deployment in the URL, so the model is dropped from the body.
+                        bodyJson = bodyJson.replaceFirst(MODEL_REGEX, "");
+                    } else if (deployment != null) {
+                        // Replace only the model value so the surrounding commas stay intact.
+                        var modelProperty = "\"" + MODEL_LITERAL + "\":\"" + deployment + "\"";
+                        bodyJson = MODEL_PROPERTY_PATTERN.matcher(bodyJson)
+                                .replaceFirst(Matcher.quoteReplacement(modelProperty));
+                    }
                     bodyJson = bodyJson.replaceFirst(EMPTY_REGEX, QUOTED_COMMA);
                     body = bodyJson;
                 }
                 if (contentType.equals(ContentType.MULTIPART_FORMDATA)) {
                     @SuppressWarnings("unchecked")
                     var bodyMap = (Map<String, Object>) request.getBody();
-                    bodyMap.remove(MODEL_LITERAL);
+                    if (deployment != null && url.contains("/assistants/")) {
+                        bodyMap.put(MODEL_LITERAL, deployment);
+                    } else {
+                        bodyMap.remove(MODEL_LITERAL);
+                    }
                     body = bodyMap;
                 }
                 request.setBody(body);
@@ -81,4 +124,40 @@ public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, St
                 .build();
     }
 
+    /**
+     * Generates an implementation of the Files interface to handle requests.
+     *
+     * @return An instance of the interface. It is created only once.
+     */
+    public OpenAI.Files files() {
+        if (fileService == null) {
+            fileService = cleverClient.create(OpenAI.Files.class);
+        }
+        return fileService;
+    }
+
+    /**
+     * Generates an implementation of the Assistants interface to handle requests.
+     *
+     * @return An instance of the interface. It is created only once.
+     */
+    public OpenAI.Assistants assistants() {
+        if (assistantService == null) {
+            assistantService = cleverClient.create(OpenAI.Assistants.class);
+        }
+        return assistantService;
+    }
+
+    /**
+     * Generates an implementation of the Threads interface to handle requests.
+     *
+     * @return An instance of the interface. It is created only once.
+     */
+    public OpenAI.Threads threads() {
+        if (threadService == null) {
+            threadService = cleverClient.create(OpenAI.Threads.class);
+        }
+        return threadService;
+    }
+
 }