Hugging Face Support (#12)
* feat: Added support for reset

* fix: Improved delete Collection test case.

* feat: Cohere support

Refs: #1

* feat: Added integration tests in a separate workflow

- Release workflow is now triggered only by a new release, and it bumps the version to the tag name

* fix: Added reset after each test so that tests do not fail because of state left behind by previous tests.

* feat: List collections

- Improvements to how collections are handled

Refs: #3

* feat: Full feature parity with the Python API (except for raw SQL execution)

Refs: #2, #3, #4

* fix: Fixing release workflow

* chore: Testing release

* chore: Testing release - removed testing in release workflow

* chore: For testing added tags to wf triggers.

* chore: For testing added tags to wf triggers.

* chore: For testing added tags to wf triggers.

* chore: Removing commit of bumped version

* fix: OSSRH creds missing

* chore: Update POM to comply with OSSRH requirements.

* chore: Finalizing release workflow.

* feat: Auto release to central

- Integration tests now apply to chroma version 0.4.3 and 0.4.4
- Some tests are conditional on API version

* chore: Fixed a typo in pom

* docs: Updated README.md with usage information

* feat: HuggingFace Support

- All tests have basic assertions to verify

Refs: #11, #7

* fix: Added HF_API_KEY key

Refs: #11, #7
tazarov authored Aug 1, 2023
1 parent 2bbab35 commit 2f3dcea
Showing 12 changed files with 304 additions and 58 deletions.
1 change: 1 addition & 0 deletions .github/workflows/integration-test.yml
@@ -60,4 +60,5 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
HF_API_KEY: ${{ secrets.HF_API_KEY }}
CHROMA_URL: ${{steps.wait-and-set.outputs.chroma-url}}
82 changes: 39 additions & 43 deletions README.md
@@ -2,7 +2,7 @@

This is a very basic/naive implementation in Java of the Chroma Vector Database API.

This client works with Chroma Version `0.4.3`
This client works with Chroma Versions `0.4.3+`

## Features

@@ -35,23 +35,22 @@ This client works with Chroma Version `0.4.3`

## TODO

- [ ] Add support for other embedding functions
- [ ] Push the package to Maven
- [x] Push the package to Maven
Central - https://docs.github.com/en/actions/publishing-packages/publishing-java-packages-with-maven
- [ ] Fluent API - make it easier for users to make use of the library
- [ ] Support for PaLM API
- [ ] Support for Sentence Transformers with Hugging Face API

## Usage

Clone the repository and install the package locally:
Add Maven dependency:

```bash
git clone git@github.com:amikos-tech/chromadb-java-client.git
```

Install dependencies:

```bash
mvn clean compile
```xml
<dependency>
<groupId>io.github.amikos-tech</groupId>
<artifactId>chromadb-java-client</artifactId>
<version>0.1.1</version>
</dependency>
```

Ensure you have an instance of Chroma running. We recommend one of the following two options:
@@ -60,51 +59,48 @@ Ensure you have a running instance of Chroma running. We recommend one of the tw
- If you are a fan of Kubernetes, you can use the Helm chart - https://github.com/amikos-tech/chromadb-chart (Note: You
will need `Docker`, `minikube` and `kubectl` installed)

Run tests:

| **Important**: Since we are using the OpenAI API, you need to set the `OPENAI_API_KEY` environment variable. Simply
create `.env` file in the root of the repository.

```bash
mvn test
```

## Example
### Example

```java
package tech.amikos;

import com.google.gson.internal.LinkedTreeMap;
import io.github.cdimascio.dotenv.Dotenv;
import tech.amikos.chromadb.Client;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.EmbeddingFunction;
import tech.amikos.chromadb.OpenAIEmbeddingFunction;
import tech.amikos.chromadb.handler.ApiException;

class TestApi {
public void testQueryExample() throws ApiException {
Client client = new Client("http://localhost:8000");
Dotenv dotenv = Dotenv.load();
String apiKey = dotenv.get("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
LinkedTreeMap<String, Object> qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
```

The above should output:

```bash
{ids=[[2, 1]], distances=[[0.28461432651150426, 0.5096168232841949]], metadatas=[[{key=value}, {key=value}]], embeddings=null, documents=[[Hello, my name is Bond. I am a Spy., Hello, my name is John. I am a Data Scientist.]]}
{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.28461432,0.50961685]]}
```
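
The example reads `CHROMA_URL` and `OPENAI_API_KEY` with `System.getenv`, which returns `null` when a variable is unset. A minimal sketch of a fail-fast lookup follows; `EnvSketch` and `requireEnv` are hypothetical names used only for illustration and are not part of the client library.

```java
import java.util.Map;

class EnvSketch {
    // Hypothetical helper: returns the variable's value or fails with a clear message.
    static String requireEnv(Map<String, String> env, String name) {
        String value = env.get(name);
        if (value == null || value.trim().isEmpty()) {
            throw new IllegalStateException("Missing required environment variable: " + name);
        }
        return value;
    }

    public static void main(String[] args) {
        // System.getenv() with no arguments returns the whole environment as a map.
        Map<String, String> env = System.getenv();
        try {
            String chromaUrl = requireEnv(env, "CHROMA_URL");
            System.out.println("Using Chroma at " + chromaUrl);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Passing the environment as a `Map` keeps the helper easy to exercise in tests without touching real process variables.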

## Development Notes
2 changes: 1 addition & 1 deletion pom.xml
@@ -27,7 +27,7 @@
<scm>
<connection>scm:git:git://github.com/amikos-tech/chromadb-java-client.git</connection>
<developerConnection>scm:git:ssh://git@github.com:amikos-tech/chromadb-java-client.git</developerConnection>
<url>https://github.com/amikos-tech/chromadb-java-client/tree/master</url>
<url>https://github.com/amikos-tech/chromadb-java-client/tree/main</url>
</scm><licenses>
<license>
<name>MIT</name>
@@ -23,4 +23,11 @@ public List<List<Float>> createEmbedding(List<String> documents) {
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents.toArray(new String[0])));
return response.getEmbeddings();
}

@Override
public List<List<Float>> createEmbedding(List<String> documents, String model) {
CohereClient client = new CohereClient(this.cohereAPIKey);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents.toArray(new String[0])).model(model));
return response.getEmbeddings();
}
}
2 changes: 2 additions & 0 deletions src/main/java/tech/amikos/chromadb/EmbeddingFunction.java
@@ -5,4 +5,6 @@
public interface EmbeddingFunction {

List<List<Float>> createEmbedding(List<String> documents);

List<List<Float>> createEmbedding(List<String> documents, String model);
}
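
The interface now carries two overloads, and each implementation in this commit repeats its request logic in both. A minimal sketch of one alternative is for the one-argument overload to delegate to the model-aware one. The interface below is a local copy so the sketch compiles on its own; `FakeEmbeddingFunction`, `DEFAULT_MODEL`, and the toy "embedding" are assumptions for illustration and call no remote API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Local copy of the interface from the diff, so the sketch is self-contained.
interface EmbeddingFunction {
    List<List<Float>> createEmbedding(List<String> documents);

    List<List<Float>> createEmbedding(List<String> documents, String model);
}

// Hypothetical implementation for illustration only.
class FakeEmbeddingFunction implements EmbeddingFunction {
    static final String DEFAULT_MODEL = "fake-default"; // assumed model name

    @Override
    public List<List<Float>> createEmbedding(List<String> documents) {
        // Delegate so the request logic lives in a single method.
        return createEmbedding(documents, DEFAULT_MODEL);
    }

    @Override
    public List<List<Float>> createEmbedding(List<String> documents, String model) {
        List<List<Float>> result = new ArrayList<>();
        for (String doc : documents) {
            // Two toy features: document length and model-name length.
            result.add(Arrays.asList((float) doc.length(), (float) model.length()));
        }
        return result;
    }
}

class EmbeddingSketch {
    public static void main(String[] args) {
        EmbeddingFunction ef = new FakeEmbeddingFunction();
        System.out.println(ef.createEmbedding(Arrays.asList("Hello", "Bond")));
        System.out.println(ef.createEmbedding(Arrays.asList("Hello"), "m"));
    }
}
```

A deterministic stub like this can also stand in for a real provider in unit tests that should not depend on network access or API keys.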
@@ -0,0 +1,33 @@
package tech.amikos.chromadb;


import tech.amikos.hf.CreateEmbeddingRequest;
import tech.amikos.hf.CreateEmbeddingResponse;
import tech.amikos.hf.HuggingFaceClient;

import java.util.List;

public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {

private final String hfAPIKey;

public HuggingFaceEmbeddingFunction(String hfAPIKey) {
this.hfAPIKey = hfAPIKey;

}

@Override
public List<List<Float>> createEmbedding(List<String> documents) {
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
return response.getEmbeddings();
}

@Override
public List<List<Float>> createEmbedding(List<String> documents, String model) {
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey);
client.modelId(model);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
return response.getEmbeddings();
}
}
11 changes: 10 additions & 1 deletion src/main/java/tech/amikos/chromadb/OpenAIEmbeddingFunction.java
@@ -23,7 +23,16 @@ public List<List<Float>> createEmbedding(List<String> documents) {
OpenAIClient client = new OpenAIClient();
CreateEmbeddingResponse response = client.apiKey(this.openAIAPIKey)
.createEmbedding(req);
// return response.getEmbeddings();
return response.getData().stream().map(emb -> emb.getEmbedding()).collect(Collectors.toList());
}

@Override
public List<List<Float>> createEmbedding(List<String> documents, String model) {
CreateEmbeddingRequest req = new CreateEmbeddingRequest().model(model);
req.input(new CreateEmbeddingRequest.Input(documents.toArray(new String[0])));
OpenAIClient client = new OpenAIClient();
CreateEmbeddingResponse response = client.apiKey(this.openAIAPIKey)
.createEmbedding(req);
return response.getData().stream().map(emb -> emb.getEmbedding()).collect(Collectors.toList());
}
}
40 changes: 40 additions & 0 deletions src/main/java/tech/amikos/hf/CreateEmbeddingRequest.java
@@ -0,0 +1,40 @@
package tech.amikos.hf;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;

import java.util.HashMap;

public class CreateEmbeddingRequest {

@SerializedName("inputs")
private String[] inputs;
@SerializedName("options")
private HashMap<String, Object> options;

public CreateEmbeddingRequest inputs(String[] inputs) {
this.inputs = inputs;
return this;
}

public CreateEmbeddingRequest options(HashMap<String, Object> options) {
this.options = options;
return this;
}

public String[] getInputs() {
return inputs;
}

public HashMap<String, Object> getOptions() {
return options;
}

public String toString() {
return this.json();
}

public String json() {
return new Gson().toJson(this);
}
}
22 changes: 22 additions & 0 deletions src/main/java/tech/amikos/hf/CreateEmbeddingResponse.java
@@ -0,0 +1,22 @@
package tech.amikos.hf;

import com.google.gson.Gson;

import java.util.List;

public class CreateEmbeddingResponse {
public List<List<Float>> embeddings;

public List<List<Float>> getEmbeddings() {
return embeddings;
}

public CreateEmbeddingResponse(List<List<Float>> embeddings) {
this.embeddings = embeddings;
}

@Override
public String toString() {
return new Gson().toJson(this);
}
}
73 changes: 73 additions & 0 deletions src/main/java/tech/amikos/hf/HuggingFaceClient.java
@@ -0,0 +1,73 @@
package tech.amikos.hf;

import com.google.gson.Gson;
import okhttp3.*;

import java.io.IOException;
import java.util.List;

//https://huggingface.co/blog/getting-started-with-embeddings
public class HuggingFaceClient {

private String baseUrl = "https://api-inference.huggingface.co/pipeline/feature-extraction/";
private String apiKey;

private OkHttpClient client = new OkHttpClient();

private String modelId = "sentence-transformers/all-MiniLM-L6-v2";
private Gson gson = new Gson();
MediaType JSON = MediaType.parse("application/json; charset=utf-8");

public HuggingFaceClient(String apiKey) {
this.apiKey = apiKey;
}


public HuggingFaceClient apiKey(String apiKey) {
this.apiKey = apiKey;
return this;
}


public HuggingFaceClient baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}

public HuggingFaceClient modelId(String modelId) {
this.modelId = modelId;
return this;
}

private String getApiKey() {
if (this.apiKey == null) {
throw new RuntimeException("API Key not set");
}
return this.apiKey;
}

public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) {
Request request = new Request.Builder()
.url(this.baseUrl + this.modelId)
.post(RequestBody.create(req.json(), JSON))
.addHeader("Accept", "application/json")
.addHeader("Content-Type", "application/json")
.addHeader("Authorization", "Bearer " + getApiKey())
.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}

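String responseData = response.body().string();

// Parse with an explicit element type: a raw List.class makes Gson produce Double values,
// which would not match the declared List<List<Float>>.
List<List<Float>> parsedResponse = gson.fromJson(responseData,
new com.google.gson.reflect.TypeToken<List<List<Float>>>() {}.getType());

return new CreateEmbeddingResponse(parsedResponse);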
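} catch (IOException e) {
// Propagate the failure instead of returning null, which callers would dereference.
throw new RuntimeException("Hugging Face embedding request failed", e);
}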
}

}
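
The client above posts a JSON body with an `inputs` array to `baseUrl + modelId` and authenticates with a bearer token. A minimal sketch of the same request shape follows, using the JDK's `java.net.http` (Java 11+) in place of OkHttp and a hand-rolled JSON body in place of Gson. The request is only built and inspected, never sent; `HfRequestSketch`, `toJsonBody`, and `buildRequest` are hypothetical names, and the escaping covers only backslashes and double quotes.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.List;
import java.util.stream.Collectors;

class HfRequestSketch {
    // Values copied from the defaults in HuggingFaceClient above.
    static final String BASE_URL = "https://api-inference.huggingface.co/pipeline/feature-extraction/";
    static final String DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2";

    // Hypothetical helper: builds {"inputs":[...]}; escapes only backslashes and double quotes.
    static String toJsonBody(List<String> inputs) {
        return inputs.stream()
                .map(s -> "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"")
                .collect(Collectors.joining(",", "{\"inputs\":[", "]}"));
    }

    // Builds the POST request without sending it.
    static HttpRequest buildRequest(String apiKey, String modelId, List<String> inputs) {
        return HttpRequest.newBuilder()
                .uri(URI.create(BASE_URL + modelId))
                .header("Accept", "application/json")
                .header("Content-Type", "application/json")
                .header("Authorization", "Bearer " + apiKey)
                .POST(HttpRequest.BodyPublishers.ofString(toJsonBody(inputs)))
                .build();
    }

    public static void main(String[] args) {
        HttpRequest req = buildRequest("hf_dummy_key", DEFAULT_MODEL, List.of("Hello"));
        System.out.println(req.method() + " " + req.uri());
        System.out.println(toJsonBody(List.of("Hello", "say \"hi\"")));
    }
}
```

Because the model identifier is part of the URL path, switching models changes the endpoint rather than the request body, which matches how `modelId(model)` is used in `HuggingFaceEmbeddingFunction`.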