Assistants API support for Azure OpenAI #72

Merged 20 commits on Apr 15, 2024
42 changes: 41 additions & 1 deletion README.md
@@ -353,7 +353,7 @@ private static ImageUrl loadImageAsBase64(String imagePath) {

## ✳ Run Examples
Examples for each OpenAI service have been created in the folder [demo](https://github.com/sashirestela/simple-openai/tree/main/src/demo/java/io/github/sashirestela/openai/demo) and you can follow these steps to execute them:
-* Clone this respository:
+* Clone this repository:
```
git clone https://github.com/sashirestela/simple-openai.git
cd simple-openai
@@ -366,6 +366,43 @@ Examples for each OpenAI service have been created in the folder [demo](https://
```
export OPENAI_API_KEY=<here goes your api key>
```
* Create environment variables for the Azure OpenAI demos

Azure OpenAI requires a separate deployment for each model. The Azure OpenAI demos therefore need two model deployments:

1. gpt-4-turbo (or similar), which supports:
- /chat/completions (including tool calls)
- /files
- /assistants (beta)
- /threads (beta)

2. gpt-4-vision-preview, which supports:
- /chat/completions (including images)

For more details, see the [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/).
Once you have the deployment URLs and the API keys, set the following environment variables:
```
export AZURE_OPENAI_BASE_URL=<your gpt-4-turbo deployment endpoint URL>
export AZURE_OPENAI_API_KEY=<here goes your regional API key>
export AZURE_OPENAI_BASE_URL_VISION=<your gpt-4 vision preview deployment endpoint URL>
export AZURE_OPENAI_API_KEY_VISION=<here goes your regional API key>
export AZURE_OPENAI_API_VERSION=2024-02-15-preview
```
Note that some models may not be available in all regions; if you cannot find a model, try a
different region. API keys are regional (one per Cognitive Services account), so multiple models
provisioned in the same region share the same key. Each region actually provides two keys, so
that one can be rotated while the other remains in use.
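
For illustration, here is a minimal sketch of how the demos turn these variables into a client; the class name `AzureEnvCheck` is hypothetical, while `SimpleOpenAIAzure` and its builder come from this PR:
```java
import io.github.sashirestela.openai.SimpleOpenAIAzure;

public class AzureEnvCheck {

    public static void main(String[] args) {
        // One client per Azure deployment; the vision demos build a second client
        // from AZURE_OPENAI_BASE_URL_VISION and AZURE_OPENAI_API_KEY_VISION.
        var openAI = SimpleOpenAIAzure.builder()
                .apiKey(System.getenv("AZURE_OPENAI_API_KEY"))
                .baseUrl(System.getenv("AZURE_OPENAI_BASE_URL"))
                .apiVersion(System.getenv("AZURE_OPENAI_API_VERSION"))
                .build();
        // openAI now exposes chatCompletions(), files(), assistants() and threads().
    }

}
```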

At the moment, simple-openai's support for Azure OpenAI includes the following OpenAI endpoints:
- /chat/completions (including tool calls)
- /chat/completions (including images)
- /files
- /assistants (beta)
- /threads (beta)

Note that streaming mode is not supported at the moment.
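
As a reference for what that means in code, the demos use the blocking call only; this minimal excerpt assumes the `openAI` client and `chatRequest` built in `ChatAzureServiceDemo`:
```java
// Blocking request/response; createStream(...) is not wired up for Azure yet.
var chatResponse = openAI.chatCompletions().create(chatRequest).join();
System.out.println(chatResponse.firstContent());
```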

* Grant execution permission to the script file:
```
chmod +x rundemo.sh
@@ -379,6 +416,8 @@ Examples for each OpenAI service have been created in the folder [demo](https://
* ```<demo>``` is mandatory and must be one of the following values:
* audio
* chat
* chatAnyscale
* chatAzure
* completion
* embedding
* file
@@ -387,6 +426,7 @@ Examples for each OpenAI service have been created in the folder [demo](https://
* model
* moderation
* assistant
* assistantAzure

* ```[debug]``` is optional and creates the ```demo.log``` file where you can see log details for each execution.
* For example, to run the chat demo with a log file: ```./rundemo.sh chat debug```
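* The Azure demos follow the same pattern once the Azure environment variables are set, for example: ```./rundemo.sh chatAzure debug``` or ```./rundemo.sh assistantAzure```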
2 changes: 1 addition & 1 deletion rundemo.sh
@@ -11,7 +11,7 @@ log_file_value="demo.log"
log_options=""

if [ ! -z "$1" ]; then
-service="${1^}"
+service=$(echo "$1" | awk '{ print toupper(substr($0, 1, 1)) substr($0, 2) }')
fi

if [ ! -z "$2" ]; then
src/demo/java/io/github/sashirestela/openai/demo/AssistantAzureServiceDemo.java
@@ -0,0 +1,173 @@
package io.github.sashirestela.openai.demo;

import io.github.sashirestela.openai.SimpleOpenAIAzure;
import io.github.sashirestela.openai.domain.assistant.AssistantRequest;
import io.github.sashirestela.openai.domain.assistant.AssistantTool;
import io.github.sashirestela.openai.domain.assistant.ImageFileContent;
import io.github.sashirestela.openai.domain.assistant.ThreadMessage;
import io.github.sashirestela.openai.domain.assistant.ThreadMessageRequest;
import io.github.sashirestela.openai.domain.assistant.ThreadRequest;
import io.github.sashirestela.openai.domain.assistant.ThreadRun;
import io.github.sashirestela.openai.domain.file.FileRequest;
import io.github.sashirestela.openai.domain.file.PurposeType;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Demonstrates the beta Assistants API against an Azure OpenAI deployment:
 * it creates and modifies an assistant, uploads a file, runs a thread, and
 * downloads any image produced by the code interpreter.
 */
public class AssistantAzureServiceDemo extends AbstractDemo {

String assistantId;
String fileId;
String threadId;
String runId;

public AssistantAzureServiceDemo(String baseUrl, String apiKey, String apiVersion) {
super(SimpleOpenAIAzure.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.apiVersion(apiVersion)
.build());

// The model name is only a placeholder: on Azure, the deployment addressed
// by the base URL determines which model actually serves the request.
var assistantRequest = AssistantRequest.builder()
.model("N/A")
.build();

var assistant = openAI.assistants().create(assistantRequest).join();
System.out.println(assistant);
assistantId = assistant.getId();
}

public void demoCreateAssistant() {
var assistantRequest = AssistantRequest.builder()
.model("N/A")
.build();

var assistant = openAI.assistants().create(assistantRequest).join();
System.out.println(assistant);
assistantId = assistant.getId();
}

public void demoRetrieveAndModifyAssistant() {
var assistant = openAI.assistants().getOne(assistantId).join();
var assistantRequest = assistant.mutate()
.name("Math Expert")
.instructions(
"You are a personal math expert. When asked a question, write and run Python code to answer the question.")
.tool(AssistantTool.CODE_INTERPRETER)
.build();

assistant = openAI.assistants().modify(assistant.getId(), assistantRequest).join();
System.out.println(assistant);
}

public void demoListAssistants() {
AtomicInteger count = new AtomicInteger();
openAI.assistants()
.getList()
.join()
.forEach(r -> System.out.println("\n#" + count.incrementAndGet() + "\n" + r));
}

public void demoUploadAssistantFile() {
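// Upload the file with the 'assistants' purpose, then attach it to the assistant.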
var fileRequest = FileRequest.builder()
.file(Paths.get("src/demo/resources/code_interpreter_file.txt"))
.purpose(PurposeType.ASSISTANTS)
.build();
var file = openAI.files().create(fileRequest).join();
var assistantFile = openAI.assistants().createFile(assistantId, file.getId()).join();
System.out.println(assistantFile);
fileId = file.getId();
}

public void demoCreateThread() {
var threadRequest = ThreadRequest.builder()
.message(ThreadMessageRequest.builder()
.role(ThreadMessageRequest.Role.USER)
.content(
"Inspect the content of the attached text file. After that plot graph of the formula requested in it.")
.build())
.build();

var thread = openAI.threads().create(threadRequest).join();
System.out.println(thread);
threadId = thread.getId();
}

public void demoRunThreadAndWaitUntilComplete() {
var run = openAI.threads().createRun(threadId, assistantId).join();
runId = run.getId();

// Poll the run status until it reaches COMPLETED; a one-second sleep between polls keeps the demo simple.
while (!run.getStatus().equals(ThreadRun.Status.COMPLETED)) {
sleep(1);
run = openAI.threads().getRun(run.getThreadId(), run.getId()).join();
}
System.out.println(run);

var messages = openAI.threads().getMessageList(threadId).join();
System.out.println(messages);
}

public void demoGetAssistantMessages() {
List<ThreadMessage> messages = openAI.threads().getMessageList(threadId).join();
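// Messages are listed newest first, so index 0 holds the latest assistant reply.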
ThreadMessage assistant = messages.get(0);
ImageFileContent assistantImageContent = assistant.getContent()
.stream()
.filter(ImageFileContent.class::isInstance)
.map(ImageFileContent.class::cast)
.findFirst()
.orElse(null);

System.out.println("All messages:");
System.out.println("=============");
System.out.println(messages);

if (assistantImageContent != null) {
System.out.println("\nAssistant answer contains an image. Downloading it now...");
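// Fetch the generated file's bytes through the files endpoint and save them to a temporary PNG.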
try (var in = openAI.files()
.getContentInputStream(assistantImageContent.getImageFile().getFileId())
.join()) {
Path tempFile = Files.createTempFile("code_interpreter", ".png");
Files.write(tempFile, in.readAllBytes());
System.out.println("Image file downloaded to: " + tempFile.toUri());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

public void demoDeleteAssistant() {
openAI.assistants().delete(assistantId).join();
}

private static void sleep(int seconds) {
try {
java.lang.Thread.sleep(1000L * seconds);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

public static void main(String[] args) {
var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
var apiKey = System.getenv("AZURE_OPENAI_API_KEY");
var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION");

var demo = new AssistantAzureServiceDemo(baseUrl, apiKey, apiVersion);

demo.addTitleAction("Demo Call Assistant Create", demo::demoCreateAssistant);
demo.addTitleAction("Demo Call Assistant Retrieve and Modify", demo::demoRetrieveAndModifyAssistant);
demo.addTitleAction("Demo Call Assistant List", demo::demoListAssistants);
demo.addTitleAction("Demo Call Assistant File Upload", demo::demoUploadAssistantFile);
demo.addTitleAction("Demo Call Assistant Thread Create", demo::demoCreateThread);
demo.addTitleAction("Demo Call Assistant Thread Run", demo::demoRunThreadAndWaitUntilComplete);
demo.addTitleAction("Demo Call Assistant Messages Get", demo::demoGetAssistantMessages);
demo.addTitleAction("Demo Call Assistant Delete", demo::demoDeleteAssistant);

demo.run();
}

}
src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java
@@ -5,7 +5,6 @@
import io.github.sashirestela.openai.demo.ChatServiceDemo.RunAlarm;
import io.github.sashirestela.openai.demo.ChatServiceDemo.Weather;
import io.github.sashirestela.openai.domain.chat.ChatRequest;
-import io.github.sashirestela.openai.domain.chat.ChatResponse;
import io.github.sashirestela.openai.domain.chat.content.ContentPartImage;
import io.github.sashirestela.openai.domain.chat.content.ContentPartText;
import io.github.sashirestela.openai.domain.chat.content.ImageUrl;
@@ -42,15 +41,6 @@ public ChatAzureServiceDemo(String baseUrl, String apiKey, String apiVersion) {
.build();
}

-public void demoCallChatStreaming() {
-var futureChat = openAI.chatCompletions().createStream(chatRequest);
-var chatResponse = futureChat.join();
-chatResponse.filter(chatResp -> chatResp.firstContent() != null)
-.map(ChatResponse::firstContent)
-.forEach(System.out::print);
-System.out.println();
-}

public void demoCallChatBlocking() {
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
@@ -113,11 +103,11 @@ public void demoCallChatWithVisionExternalImage() {
.temperature(0.0)
.maxTokens(500)
.build();
-var chatResponse = openAI.chatCompletions().createStream(chatRequest).join();
-chatResponse.filter(chatResp -> chatResp.firstContent() != null)
-.map(chatResp -> chatResp.firstContent())
-.forEach(System.out::print);
+var chatResponse = openAI.chatCompletions().create(chatRequest).join();
+System.out.println(chatResponse.firstContent());
System.out.println();

}

@@ -131,11 +121,8 @@ public void demoCallChatWithVisionLocalImage() {
.temperature(0.0)
.maxTokens(500)
.build();
-var chatResponse = openAI.chatCompletions().createStream(chatRequest).join();
-chatResponse.filter(chatResp -> chatResp.firstContent() != null)
-.map(chatResp -> chatResp.firstContent())
-.forEach(System.out::print);
-System.out.println();
+var chatResponse = openAI.chatCompletions().create(chatRequest).join();
+System.out.println(chatResponse.firstContent());
}

@@ -152,23 +139,31 @@ private static ImageUrl loadImageAsBase64(String imagePath) {
}
}

-public static void main(String[] args) {
+private static void chatWithFunctionsDemo(String apiVersion) {
var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
var apiKey = System.getenv("AZURE_OPENAI_API_KEY");
-var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION");
+var chatDemo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion);
+chatDemo.addTitleAction("Call Chat (Blocking Approach)", chatDemo::demoCallChatBlocking);
+chatDemo.addTitleAction("Call Chat with Functions", chatDemo::demoCallChatWithFunctions);

-var demo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion);
+chatDemo.run();
+}

-demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking);
-if (baseUrl.contains("gpt-35-turbo")) {
-demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions);
-} else if (baseUrl.contains("gpt-4")) {
-demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming);
-demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage);
-demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage);
-}
+private static void chatWithVisionDemo(String apiVersion) {
+var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL_VISION");
+var apiKey = System.getenv("AZURE_OPENAI_API_KEY_VISION");
+var visionDemo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion);
+visionDemo.addTitleAction("Call Chat with Vision (External image)",
+visionDemo::demoCallChatWithVisionExternalImage);
+visionDemo.addTitleAction("Call Chat with Vision (Local image)", visionDemo::demoCallChatWithVisionLocalImage);
+visionDemo.run();
+}
+
+public static void main(String[] args) {
+var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION");

-demo.run();
+chatWithFunctionsDemo(apiVersion);
+chatWithVisionDemo(apiVersion);
}

}