Skip to content

Commit

Permalink
WIP on azure assistant API support
Browse files Browse the repository at this point in the history
  • Loading branch information
the-gigi committed Apr 4, 2024
1 parent 743ae0f commit e9203f0
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 8 deletions.
5 changes: 3 additions & 2 deletions rundemo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ log_file_value="demo.log"
log_options=""

if [ ! -z "$1" ]; then
service="${1^}"
service=$(echo "$1" | awk '{ print toupper(substr($0, 1, 1)) substr($0, 2) }')
fi

if [ ! -z "$2" ]; then
log_level_value="$2"
fi

if [ "$log_level_value" != "off" ]; then
log_options="${log_level_param}=${log_level_value} ${log_file_param}=${log_file_value}"
log_options="${log_level_param}=${log_level_value}"
#log_options="${log_level_param}=${log_level_value} ${log_file_param}=${log_file_value}"
fi

main_class="io.github.sashirestela.openai.demo.${service}ServiceDemo"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package io.github.sashirestela.openai.demo;

import io.github.sashirestela.openai.SimpleOpenAIAzure;
import io.github.sashirestela.openai.domain.assistant.AssistantRequest;
import io.github.sashirestela.openai.domain.assistant.AssistantTool;
import io.github.sashirestela.openai.domain.assistant.Events;
import io.github.sashirestela.openai.domain.assistant.ImageFileContent;
import io.github.sashirestela.openai.domain.assistant.TextContent;
import io.github.sashirestela.openai.domain.assistant.ThreadMessage;
import io.github.sashirestela.openai.domain.assistant.ThreadMessageDelta;
import io.github.sashirestela.openai.domain.assistant.ThreadMessageRequest;
import io.github.sashirestela.openai.domain.assistant.ThreadRequest;
import io.github.sashirestela.openai.domain.assistant.ThreadRun;
import io.github.sashirestela.openai.domain.assistant.ThreadRunRequest;
import io.github.sashirestela.openai.domain.file.FileRequest;
import io.github.sashirestela.openai.domain.file.PurposeType;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class AssistantAzureServiceDemo extends AbstractDemo {

String assistantId;
String fileId;
String threadId;
String runId;

public AssistantAzureServiceDemo(String baseUrl, String apiKey, String apiVersion) {
super(SimpleOpenAIAzure.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.apiVersion(apiVersion)
.build());

var assistantRequest = AssistantRequest.builder()
.model("N/A")
.build();

var assistant = openAI.assistants().create(assistantRequest).join();
System.out.println(assistant);
assistantId = assistant.getId();
}



public void demoCreateAssistant() {
var assistantRequest = AssistantRequest.builder()
.model("N/A")
.build();

var assistant = openAI.assistants().create(assistantRequest).join();
System.out.println(assistant);
assistantId = assistant.getId();
}

public void demoRetrieveAndModifyAssistant() {
var assistant = openAI.assistants().getOne(assistantId).join();
var assistantRequest = assistant.mutate()
.name("Math Expert")
.instructions(
"You are a personal math expert. When asked a question, write and run Python code to answer the question.")
.tool(AssistantTool.CODE_INTERPRETER)
.build();

assistant = openAI.assistants().modify(assistant.getId(), assistantRequest).join();
System.out.println(assistant);
}

public void demoListAssistants() {
AtomicInteger count = new AtomicInteger();
openAI.assistants()
.getList()
.join()
.forEach(r -> System.out.println("\n#" + count.incrementAndGet() + "\n" + r));
}

public void demoUploadAssistantFile() {
var fileRequest = FileRequest.builder()
.file(Paths.get("src/demo/resources/code_interpreter_file.txt"))
.purpose(PurposeType.ASSISTANTS)
.build();
var file = openAI.files().create(fileRequest).join();
var assistantFile = openAI.assistants().createFile(assistantId, file.getId()).join();
System.out.println(assistantFile);
fileId = file.getId();
}

public void demoCreateThread() {
var threadRequest = ThreadRequest.builder()
.message(ThreadMessageRequest.builder()
.role("user")
.content(
"Inspect the content of the attached text file. After that plot graph of the formula requested in it.")
.build())
.build();

var thread = openAI.threads().create(threadRequest).join();
System.out.println(thread);
threadId = thread.getId();
}

public void demoRunThreadAndWaitUntilComplete() {
var run = openAI.threads().createRun(threadId, assistantId).join();
runId = run.getId();

while (!run.getStatus().equals(ThreadRun.Status.COMPLETED)) {
sleep(1);
run = openAI.threads().getRun(run.getThreadId(), run.getId()).join();
}
System.out.println(run);

var messages = openAI.threads().getMessageList(threadId).join();
System.out.println(messages);
}

public void demoRunThreadAndStream() {
var request = ThreadRunRequest.builder().assistantId(assistantId).build();
var response = openAI.threads().createRunStream(threadId, request).join();
response.filter(e -> e.getName().equals(Events.THREAD_MESSAGE_DELTA))
.map(e -> ((TextContent) ((ThreadMessageDelta) e.getData()).getDelta().getContent().get(0)).getValue())
.forEach(System.out::print);
System.out.println();
}

public void demoGetAssistantMessages() {
List<ThreadMessage> messages = openAI.threads().getMessageList(threadId).join();
ThreadMessage assistant = messages.get(0);
ImageFileContent assistantImageContent = assistant.getContent()
.stream()
.filter(ImageFileContent.class::isInstance)
.map(ImageFileContent.class::cast)
.findFirst()
.orElse(null);

System.out.println("All messages:");
System.out.println("=============");
System.out.println(messages);

if (assistantImageContent != null) {
System.out.println("\nAssistant answer contains an image. Downloading it now...");
try (var in = openAI.files()
.getContentInputStream(assistantImageContent.getImageFile().getFileId())
.join()) {
Path tempFile = Files.createTempFile("code_interpreter", ".png");
Files.write(tempFile, in.readAllBytes());
System.out.println("Image file downloaded to: " + tempFile.toUri());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

public void demoDeleteAssistant() {
openAI.assistants().delete(assistantId).join();
}

private static void sleep(int seconds) {
try {
java.lang.Thread.sleep(1000L * seconds);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

public static void main(String[] args) {
var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
var apiKey = System.getenv("AZURE_OPENAI_API_KEY");
var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION");

var demo = new AssistantAzureServiceDemo(baseUrl, apiKey, apiVersion);

demo.addTitleAction("Demo Call Assistant Create", demo::demoCreateAssistant);
demo.addTitleAction("Demo Call Assistant Retrieve and Modify", demo::demoRetrieveAndModifyAssistant);
demo.addTitleAction("Demo Call Assistant List", demo::demoListAssistants);
demo.addTitleAction("Demo Call Assistant File Upload", demo::demoUploadAssistantFile);
demo.addTitleAction("Demo Call Assistant Thread Create", demo::demoCreateThread);
demo.addTitleAction("Demo Call Assistant Thread Run", demo::demoRunThreadAndWaitUntilComplete);
demo.addTitleAction("Demo Call Assistant Thread Run Stream", demo::demoRunThreadAndStream);
demo.addTitleAction("Demo Call Assistant Messages Get", demo::demoGetAssistantMessages);
demo.addTitleAction("Demo Call Assistant Delete", demo::demoDeleteAssistant);

demo.run();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ public static void main(String[] args) {
var demo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion);

demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking);
if (baseUrl.contains("gpt-35-turbo")) {
if (baseUrl.contains("gpt-4-0125-Preview")) {
demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions);
} else if (baseUrl.contains("gpt-4")) {
} else if (baseUrl.contains("gpt-35-turbo")) {
demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming);
} else if (baseUrl.contains("gpt-4-vision")) {
demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage);
demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage);
}

demo.run();
}

Expand Down
75 changes: 72 additions & 3 deletions src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.github.sashirestela.cleverclient.http.HttpRequestData;
import io.github.sashirestela.cleverclient.support.ContentType;
import io.github.sashirestela.openai.support.Constant;
import java.util.regex.Pattern;
import lombok.Builder;
import lombok.NonNull;

Expand All @@ -17,6 +18,10 @@
*/
public class SimpleOpenAIAzure extends BaseSimpleOpenAI {

private OpenAI.Files fileService;
private OpenAI.Assistants assistantService;
private OpenAI.Threads threadService;

/**
* Constructor used to generate a builder.
*
Expand All @@ -34,11 +39,23 @@ public SimpleOpenAIAzure(@NonNull String apiKey, @NonNull String baseUrl, @NonNu
super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, apiVersion, httpClient));
}

private static String extractDeployment(String url) {
final String DEPLOYMENT_REGEX = "/deployments/([^/]+)/";

var pattern = Pattern.compile(DEPLOYMENT_REGEX);
var matcher = pattern.matcher(url);

if (matcher.find()) {
return matcher.group(1); // Return the first matched group
}

return null; // Return null if no match was found
}

public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, String apiVersion,
HttpClient httpClient) {

var headers = Map.of(Constant.AZURE_APIKEY_HEADER, apiKey);

UnaryOperator<HttpRequestData> requestInterceptor = request -> {
final String VERSION_REGEX = "(\\/v\\d+\\.*\\d*)";
final String MODEL_REGEX = ",?\"model\":\"[^\"]*\",?";
Expand All @@ -50,21 +67,37 @@ public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, St
var contentType = request.getContentType();
var body = request.getBody();

var deployment = extractDeployment(url);

url += (url.contains("?") ? "&" : "?") + Constant.AZURE_API_VERSION + "=" + apiVersion;
url = url.replaceFirst(VERSION_REGEX, "");

// Strip deployment from URL unless it's /chat/completions call
if (!url.contains("/chat/completions")) {
url = url.replaceFirst("/deployments/[^/]+/", "/");
}

request.setUrl(url);

if (contentType != null) {
if (contentType.equals(ContentType.APPLICATION_JSON)) {
var bodyJson = (String) request.getBody();
bodyJson = bodyJson.replaceFirst(MODEL_REGEX, "");
var model = "";
if (!url.contains("/chat/completions/")) {
model = "\"" + MODEL_LITERAL + "\":\"" + deployment + "\"";
}
bodyJson = bodyJson.replaceFirst(MODEL_REGEX, model);
bodyJson = bodyJson.replaceFirst(EMPTY_REGEX, QUOTED_COMMA);
body = bodyJson;
}
if (contentType.equals(ContentType.MULTIPART_FORMDATA)) {
@SuppressWarnings("unchecked")
var bodyMap = (Map<String, Object>) request.getBody();
bodyMap.remove(MODEL_LITERAL);
if (url.contains("/assistants/")) {
bodyMap.put(MODEL_LITERAL, deployment);
} else {
bodyMap.remove(MODEL_LITERAL);
}
body = bodyMap;
}
request.setBody(body);
Expand All @@ -81,4 +114,40 @@ public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, St
.build();
}

/**
* Generates an implementation of the Files interface to handle requests.
*
* @return An instance of the interface. It is created only once.
*/
public OpenAI.Files files() {
if (fileService == null) {
fileService = cleverClient.create(OpenAI.Files.class);
}
return fileService;
}

/**
* Generates an implementation of the Assistant interface to handle requests.
*
* @return An instance of the interface. It is created only once.
*/
public OpenAI.Assistants assistants() {
if (assistantService == null) {
assistantService = cleverClient.create(OpenAI.Assistants.class);
}
return assistantService;
}

/**
* Spawns a single instance of the Threads interface to manage requests.
*
* @return An instance of the interface. It is created only once.
*/
public OpenAI.Threads threads() {
if (threadService == null) {
threadService = cleverClient.create(OpenAI.Threads.class);
}
return threadService;
}

}

0 comments on commit e9203f0

Please sign in to comment.