From 2f3dceac2f64a5c9d03fd77c4c87d08f5a8d5211 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 1 Aug 2023 14:07:14 +0300 Subject: [PATCH] Hugging Face Support (#12) * feat: Added support for reset * fix: Improved delete Collection test case. * feat: Cohere support Refs: #1 * feat: Added integration tests in a separate workflow - Release workflow now is only triggered upon new release and it bumps version to tag name * fix: Added reset after each test to ensure test are not failing due to what previous tests have done. * feat: List collections - Improvements of how Collections are handled Refs: #3 * feat: Full feature parity with the Python API (except for raw SQL execution) Refs: #2, #3, #4 * fix: Fixing release workflow * chore: Testing release * chore: Testing release - removed testing in release workflow * chore: For testing added tags to wf triggers. * chore: For testing added tags to wf triggers. * chore: For testing added tags to wf triggers. * chore: Removing commit of bumped version * fix: OSSRH creds missing * chore: Update POM to comply with OSSRH requirements. * chore: Finalizing release workflow. * feat: Auto release to central - Integration tests now apply to chroma version 0.4.3 and 0.4.4 - Some tests are conditional on API version * chore: Fixed a typo in pom * docs: Updated README.md with usage information * feat: HuggingFace Support - All tests have basic assertions to verify Refs: #11, #7 * fix: Added HF_API_KEY key Refs: #11, #7 --- .github/workflows/integration-test.yml | 1 + README.md | 82 +++++++++---------- pom.xml | 2 +- .../chromadb/CohereEmbeddingFunction.java | 7 ++ .../amikos/chromadb/EmbeddingFunction.java | 2 + .../HuggingFaceEmbeddingFunction.java | 33 ++++++++ .../chromadb/OpenAIEmbeddingFunction.java | 11 ++- .../amikos/hf/CreateEmbeddingRequest.java | 40 +++++++++ .../amikos/hf/CreateEmbeddingResponse.java | 22 +++++ .../tech/amikos/hf/HuggingFaceClient.java | 73 +++++++++++++++++ src/test/java/TestAPI.java | 68 ++++++++++++--- src/test/java/TestHuggingFaceClient.java | 21 +++++ 12 files changed, 304 insertions(+), 58 deletions(-) create mode 100644 src/main/java/tech/amikos/chromadb/HuggingFaceEmbeddingFunction.java create mode 100644 src/main/java/tech/amikos/hf/CreateEmbeddingRequest.java create mode 100644 src/main/java/tech/amikos/hf/CreateEmbeddingResponse.java create mode 100644 src/main/java/tech/amikos/hf/HuggingFaceClient.java create mode 100644 src/test/java/TestHuggingFaceClient.java diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 1b2fe02..97400bc 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -60,4 +60,5 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} + HF_API_KEY: ${{ secrets.HF_API_KEY }} CHROMA_URL: ${{steps.wait-and-set.outputs.chroma-url}} diff --git a/README.md b/README.md index 65d8023..44b9407 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is a very basic/naive implementation in Java of the Chroma Vector Database API. -This client works with Chroma Version `0.4.3` +This client works with Chroma Versions `0.4.3+` ## Features @@ -35,23 +35,22 @@ This client works with Chroma Version `0.4.3` ## TODO -- [ ] Add support for other embedding functions -- [ ] Push the package to Maven +- [x] Push the package to Maven Central - https://docs.github.com/en/actions/publishing-packages/publishing-java-packages-with-maven - [ ] Fluent API - make it easier for users to make use of the library +- [ ] Support for PaLM API +- [ ] Support for Sentence Transformers with Hugging Face API ## Usage -Clone the repository and install the package locally: +Add Maven dependency: -```bash -git clone git@github.com:amikos-tech/chromadb-java-client.git -``` - -Install dependencies: - -```bash -mvn clean compile +```xml + + io.github.amikos-tech + chromadb-java-client + 0.1.1 + ``` Ensure you have a running instance of Chroma running. We recommend one of the two following options: @@ -60,43 +59,40 @@ Ensure you have a running instance of Chroma running. We recommend one of the tw - If you are a fan of Kubernetes, you can use the Helm chart - https://github.com/amikos-tech/chromadb-chart (Note: You will need `Docker`, `minikube` and `kubectl` installed) -Run tests: - -| **Important**: Since we are using the OpenAI API, you need to set the `OPENAI_API_KEY` environment variable. Simply -create `.env` file in the root of the repository. - -```bash -mvn test -``` - -## Example +### Example ```java +package tech.amikos; + import com.google.gson.internal.LinkedTreeMap; -import io.github.cdimascio.dotenv.Dotenv; import tech.amikos.chromadb.Client; import tech.amikos.chromadb.Collection; import tech.amikos.chromadb.EmbeddingFunction; import tech.amikos.chromadb.OpenAIEmbeddingFunction; -import tech.amikos.chromadb.handler.ApiException; - -class TestApi { - public void testQueryExample() throws ApiException { - Client client = new Client("http://localhost:8000"); - Dotenv dotenv = Dotenv.load(); - String apiKey = dotenv.get("OPENAI_API_KEY"); - EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); - Collection collection = client.createCollection("test-collection", null, true, ef); - List> metadata = new ArrayList<>(); - metadata.add(new HashMap() {{ - put("type", "scientist"); - }}); - metadata.add(new HashMap() {{ - put("type", "spy"); - }}); - collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); - LinkedTreeMap qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); - System.out.println(qr); + +import java.util.*; + +public class Main { + public static void main(String[] args) { + try { + Client client = new Client(System.getenv("CHROMA_URL")); + String apiKey = System.getenv("OPENAI_API_KEY"); + EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); + Collection collection = client.createCollection("test-collection", null, true, ef); + List> metadata = new ArrayList<>(); + metadata.add(new HashMap() {{ + put("type", "scientist"); + }}); + metadata.add(new HashMap() {{ + put("type", "spy"); + }}); + collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); + Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); + System.out.println(qr); + } catch (Exception e) { + e.printStackTrace(); + System.out.println(e); + } } } ``` @@ -104,7 +100,7 @@ class TestApi { The above should output: ```bash -{ids=[[2, 1]], distances=[[0.28461432651150426, 0.5096168232841949]], metadatas=[[{key=value}, {key=value}]], embeddings=null, documents=[[Hello, my name is Bond. I am a Spy., Hello, my name is John. I am a Data Scientist.]]} +{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.28461432,0.50961685]]} ``` ## Development Notes diff --git a/pom.xml b/pom.xml index fc5c1b0..fa63f6b 100644 --- a/pom.xml +++ b/pom.xml @@ -27,7 +27,7 @@ scm:git:git://github.com/amikos-tech/chromadb-java-client.git scm:git:ssh://git@github.com:amikos-tech/chromadb-java-client.git - https://github.com/amikos-tech/chromadb-java-client/tree/master + https://github.com/amikos-tech/chromadb-java-client/tree/main MIT diff --git a/src/main/java/tech/amikos/chromadb/CohereEmbeddingFunction.java b/src/main/java/tech/amikos/chromadb/CohereEmbeddingFunction.java index ac3195e..e828164 100644 --- a/src/main/java/tech/amikos/chromadb/CohereEmbeddingFunction.java +++ b/src/main/java/tech/amikos/chromadb/CohereEmbeddingFunction.java @@ -23,4 +23,11 @@ public List> createEmbedding(List documents) { CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents.toArray(new String[0]))); return response.getEmbeddings(); } + + @Override + public List> createEmbedding(List documents, String model) { + CohereClient client = new CohereClient(this.cohereAPIKey); + CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents.toArray(new String[0])).model(model)); + return response.getEmbeddings(); + } } diff --git a/src/main/java/tech/amikos/chromadb/EmbeddingFunction.java b/src/main/java/tech/amikos/chromadb/EmbeddingFunction.java index d6b0a8c..bda1477 100644 --- a/src/main/java/tech/amikos/chromadb/EmbeddingFunction.java +++ b/src/main/java/tech/amikos/chromadb/EmbeddingFunction.java @@ -5,4 +5,6 @@ public interface EmbeddingFunction { List> createEmbedding(List documents); + + List> createEmbedding(List documents, String model); } diff --git a/src/main/java/tech/amikos/chromadb/HuggingFaceEmbeddingFunction.java b/src/main/java/tech/amikos/chromadb/HuggingFaceEmbeddingFunction.java new file mode 100644 index 0000000..d1c7d33 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/HuggingFaceEmbeddingFunction.java @@ -0,0 +1,33 @@ +package tech.amikos.chromadb; + + +import tech.amikos.hf.CreateEmbeddingRequest; +import tech.amikos.hf.CreateEmbeddingResponse; +import tech.amikos.hf.HuggingFaceClient; + +import java.util.List; + +public class HuggingFaceEmbeddingFunction implements EmbeddingFunction { + + private final String hfAPIKey; + + public HuggingFaceEmbeddingFunction(String hfAPIKey) { + this.hfAPIKey = hfAPIKey; + + } + + @Override + public List> createEmbedding(List documents) { + HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey); + CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0]))); + return response.getEmbeddings(); + } + + @Override + public List> createEmbedding(List documents, String model) { + HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey); + client.modelId(model); + CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0]))); + return response.getEmbeddings(); + } +} diff --git a/src/main/java/tech/amikos/chromadb/OpenAIEmbeddingFunction.java b/src/main/java/tech/amikos/chromadb/OpenAIEmbeddingFunction.java index 55b3a72..4602088 100644 --- a/src/main/java/tech/amikos/chromadb/OpenAIEmbeddingFunction.java +++ b/src/main/java/tech/amikos/chromadb/OpenAIEmbeddingFunction.java @@ -23,7 +23,16 @@ public List> createEmbedding(List documents) { OpenAIClient client = new OpenAIClient(); CreateEmbeddingResponse response = client.apiKey(this.openAIAPIKey) .createEmbedding(req); -// return response.getEmbeddings(); + return response.getData().stream().map(emb -> emb.getEmbedding()).collect(Collectors.toList()); + } + + @Override + public List> createEmbedding(List documents, String model) { + CreateEmbeddingRequest req = new CreateEmbeddingRequest().model(model); + req.input(new CreateEmbeddingRequest.Input(documents.toArray(new String[0]))); + OpenAIClient client = new OpenAIClient(); + CreateEmbeddingResponse response = client.apiKey(this.openAIAPIKey) + .createEmbedding(req); return response.getData().stream().map(emb -> emb.getEmbedding()).collect(Collectors.toList()); } } diff --git a/src/main/java/tech/amikos/hf/CreateEmbeddingRequest.java b/src/main/java/tech/amikos/hf/CreateEmbeddingRequest.java new file mode 100644 index 0000000..4a8cc19 --- /dev/null +++ b/src/main/java/tech/amikos/hf/CreateEmbeddingRequest.java @@ -0,0 +1,40 @@ +package tech.amikos.hf; + +import com.google.gson.Gson; +import com.google.gson.annotations.SerializedName; + +import java.util.HashMap; + +public class CreateEmbeddingRequest { + + @SerializedName("inputs") + private String[] inputs; + @SerializedName("options") + private HashMap options; + + public CreateEmbeddingRequest inputs(String[] inputs) { + this.inputs = inputs; + return this; + } + + public CreateEmbeddingRequest options(HashMap options) { + this.options = options; + return this; + } + + public String[] getInputs() { + return inputs; + } + + public HashMap getOptions() { + return options; + } + + public String toString() { + return this.json(); + } + + public String json() { + return new Gson().toJson(this); + } +} diff --git a/src/main/java/tech/amikos/hf/CreateEmbeddingResponse.java b/src/main/java/tech/amikos/hf/CreateEmbeddingResponse.java new file mode 100644 index 0000000..0c819a8 --- /dev/null +++ b/src/main/java/tech/amikos/hf/CreateEmbeddingResponse.java @@ -0,0 +1,22 @@ +package tech.amikos.hf; + +import com.google.gson.Gson; + +import java.util.List; + +public class CreateEmbeddingResponse { + public List> embeddings; + + public List> getEmbeddings() { + return embeddings; + } + + public CreateEmbeddingResponse(List> embeddings) { + this.embeddings = embeddings; + } + + @Override + public String toString() { + return new Gson().toJson(this); + } +} diff --git a/src/main/java/tech/amikos/hf/HuggingFaceClient.java b/src/main/java/tech/amikos/hf/HuggingFaceClient.java new file mode 100644 index 0000000..dc65205 --- /dev/null +++ b/src/main/java/tech/amikos/hf/HuggingFaceClient.java @@ -0,0 +1,73 @@ +package tech.amikos.hf; + +import com.google.gson.Gson; +import okhttp3.*; + +import java.io.IOException; +import java.util.List; + +//https://huggingface.co/blog/getting-started-with-embeddings +public class HuggingFaceClient { + + private String baseUrl = "https://api-inference.huggingface.co/pipeline/feature-extraction/"; + private String apiKey; + + private OkHttpClient client = new OkHttpClient(); + + private String modelId = "sentence-transformers/all-MiniLM-L6-v2"; + private Gson gson = new Gson(); + MediaType JSON = MediaType.parse("application/json; charset=utf-8"); + + public HuggingFaceClient(String apiKey) { + this.apiKey = apiKey; + } + + + public HuggingFaceClient apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + + public HuggingFaceClient baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public HuggingFaceClient modelId(String modelId) { + this.modelId = modelId; + return this; + } + + private String getApiKey() { + if (this.apiKey == null) { + throw new RuntimeException("API Key not set"); + } + return this.apiKey; + } + + public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) { + Request request = new Request.Builder() + .url(this.baseUrl + this.modelId) + .post(RequestBody.create(req.json(), JSON)) + .addHeader("Accept", "application/json") + .addHeader("Content-Type", "application/json") + .addHeader("Authorization", "Bearer " + getApiKey()) + .build(); + try (Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw new IOException("Unexpected code " + response); + } + + String responseData = response.body().string(); + + List parsedResponse = gson.fromJson(responseData, List.class); + + return new CreateEmbeddingResponse(parsedResponse); + } catch (IOException e) { + e.printStackTrace(); + } + return null; + } + +} diff --git a/src/test/java/TestAPI.java b/src/test/java/TestAPI.java index 70d5343..7c2c2d5 100644 --- a/src/test/java/TestAPI.java +++ b/src/test/java/TestAPI.java @@ -5,6 +5,7 @@ import tech.amikos.chromadb.handler.ApiException; import java.io.IOException; +import java.math.BigDecimal; import java.util.*; import static org.junit.Assert.*; @@ -15,19 +16,11 @@ public class TestAPI { @Test public void testHeartbeat() throws ApiException, IOException { - Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); - System.out.println(client.heartbeat()); - } - - @Test - public void testGetCollection() throws ApiException, IOException { Utils.loadEnvFile(".env"); Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); - client.reset(); - String apiKey = Utils.getEnvOrProperty("OPENAI_API_KEY"); - EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); - client.createCollection("test-collection", null, true, ef); - System.out.println(client.getCollection("test-collection", ef)); + Map hb = client.heartbeat(); + System.out.println(hb); + assertTrue(hb.containsKey("nanosecond heartbeat")); } @Test @@ -39,6 +32,7 @@ public void testGetCollectionGet() throws ApiException, IOException { EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); client.createCollection("test-collection", null, true, ef); System.out.println(client.getCollection("test-collection", ef).get()); + assertTrue(client.getCollection("test-collection", ef).get() != null); } @@ -49,8 +43,9 @@ public void testCreateCollection() throws ApiException { client.reset(); String apiKey = Utils.getEnvOrProperty("OPENAI_API_KEY"); EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); - Collection resp = client.createCollection("test-collection", null, true, ef); - System.out.println(resp); + Collection collection = client.createCollection("test-collection", null, true, ef); + System.out.println(collection); + assertEquals(collection.getName(), "test-collection"); } @Test @@ -85,6 +80,7 @@ public void testCreateUpsert() throws ApiException { Object resp = collection.upsert(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); System.out.println(resp); System.out.println(collection.get()); + assertTrue(collection.count() == 1); } @Test @@ -102,6 +98,7 @@ public void testCreateAdd() throws ApiException { Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); System.out.println(resp); System.out.println(collection.get()); + assertTrue(collection.count() == 1); } @Test @@ -120,6 +117,7 @@ public void testQuery() throws ApiException { collection.add(null, metadata, Arrays.asList("Hello, my name is Bond. I am a Spy."), Arrays.asList("2")); Collection.QueryResponse qr = collection.query(Arrays.asList("name is John"), 10, null, null, null); System.out.println(qr); + assertEquals(qr.getIds().get(0).get(0), "1"); //we check that Bond doc is first } @Test @@ -140,6 +138,7 @@ public void testQueryExample() throws ApiException { collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); System.out.println(qr); + assertEquals(qr.getIds().get(0).get(0), "2"); //we check that Bond doc is first } @Test @@ -175,6 +174,7 @@ public void testCreateAddCohere() throws ApiException { Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); System.out.println(resp); System.out.println(collection.get()); + assertTrue(collection.get().getDocuments().size() == 1); } @Test @@ -227,6 +227,7 @@ public void testCollectionCount() throws ApiException { System.out.println(resp); System.out.println(collection.get()); System.out.println(collection.count()); + assertTrue(collection.count() == 1); } @Test @@ -246,6 +247,7 @@ public void testCollectionDeleteIds() throws ApiException { System.out.println(collection.get()); System.out.println(collection.deleteWithIds(Arrays.asList("1"))); System.out.println(collection.get()); + assertTrue(collection.get().getDocuments().size() == 0); } @Test @@ -407,5 +409,45 @@ public void testCollectionUpdateEmbeddings() throws ApiException { } + @Test + public void testCreateAddHF() throws ApiException { + Utils.loadEnvFile(".env"); + Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); + client.reset(); + String apiKey = Utils.getEnvOrProperty("HF_API_KEY"); + EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey); + Collection collection = client.createCollection("test-collection", null, true, ef); + List> metadata = new ArrayList<>(); + metadata.add(new HashMap() {{ + put("key", "value"); + }}); + Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); + System.out.println(resp); + System.out.println(collection.get()); + assertTrue(collection.count() == 1); + } + + @Test + public void testQueryExampleHF() throws ApiException { + Utils.loadEnvFile(".env"); + Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); + client.reset(); + String apiKey = Utils.getEnvOrProperty("HF_API_KEY"); + EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey); + Collection collection = client.createCollection("test-collection", null, true, ef); + List> metadata = new ArrayList<>(); + metadata.add(new HashMap() {{ + put("type", "scientist"); + }}); + metadata.add(new HashMap() {{ + put("type", "spy"); + }}); + List texts = Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."); + collection.add(null, metadata, texts, Arrays.asList("1", "2")); + Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); + assertEquals(qr.getIds().get(0).get(0), "2"); //we check that Bond doc is first + System.out.println(qr); + } + } diff --git a/src/test/java/TestHuggingFaceClient.java b/src/test/java/TestHuggingFaceClient.java new file mode 100644 index 0000000..6b43de1 --- /dev/null +++ b/src/test/java/TestHuggingFaceClient.java @@ -0,0 +1,21 @@ +import org.junit.Test; +import tech.amikos.hf.CreateEmbeddingRequest; +import tech.amikos.hf.CreateEmbeddingResponse; +import tech.amikos.hf.HuggingFaceClient; + +import static org.junit.Assert.assertEquals; + +public class TestHuggingFaceClient { + + + @Test + public void testEmbeddings() { + Utils.loadEnvFile(".env"); + HuggingFaceClient client = new HuggingFaceClient(Utils.getEnvOrProperty("HF_API_KEY")); + client.modelId("sentence-transformers/all-MiniLM-L6-v2"); + String[] texts = {"Hello world", "How are you?"}; + CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(texts)); + assertEquals(2, response.getEmbeddings().size()); + } +} +