-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: Added support for reset * fix: Improved delete Collection test case. * feat: Cohere support Refs: #1 * feat: Added integration tests in a separate workflow - Release workflow now is only triggered upon new release and it bumps version to tag name * fix: Added reset after each test to ensure test are not failing due to what previous tests have done. * feat: List collections - Improvements of how Collections are handled Refs: #3 * feat: Full feature parity with the Python API (except for raw SQL execution) Refs: #2, #3, #4 * fix: Fixing release workflow * chore: Testing release * chore: Testing release - removed testing in release workflow * chore: For testing added tags to wf triggers. * chore: For testing added tags to wf triggers. * chore: For testing added tags to wf triggers. * chore: Removing commit of bumped version * fix: OSSRH creds missing * chore: Update POM to comply with OSSRH requirements. * chore: Finalizing release workflow. * feat: Auto release to central - Integration tests now apply to chroma version 0.4.3 and 0.4.4 - Some tests are conditional on API version * chore: Fixed a typo in pom * docs: Updated README.md with usage information * feat: HuggingFace Support - All tests have basic assertions to verify Refs: #11, #7 * fix: Added HF_API_KEY key Refs: #11, #7
- Loading branch information
Showing
12 changed files
with
304 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
This is a very basic/naive implementation in Java of the Chroma Vector Database API. | ||
|
||
This client works with Chroma Version `0.4.3` | ||
This client works with Chroma Versions `0.4.3+` | ||
|
||
## Features | ||
|
||
|
@@ -35,23 +35,22 @@ This client works with Chroma Version `0.4.3` | |
|
||
## TODO | ||
|
||
- [ ] Add support for other embedding functions | ||
- [ ] Push the package to Maven | ||
- [x] Push the package to Maven | ||
Central - https://docs.github.com/en/actions/publishing-packages/publishing-java-packages-with-maven | ||
- [ ] Fluent API - make it easier for users to make use of the library | ||
- [ ] Support for PaLM API | ||
- [ ] Support for Sentence Transformers with Hugging Face API | ||
|
||
## Usage | ||
|
||
Clone the repository and install the package locally: | ||
Add Maven dependency: | ||
|
||
```bash | ||
git clone [email protected]:amikos-tech/chromadb-java-client.git | ||
``` | ||
|
||
Install dependencies: | ||
|
||
```bash | ||
mvn clean compile | ||
```xml | ||
<dependency> | ||
<groupId>io.github.amikos-tech</groupId> | ||
<artifactId>chromadb-java-client</artifactId> | ||
<version>0.1.1</version> | ||
</dependency> | ||
``` | ||
|
||
Ensure you have a running instance of Chroma running. We recommend one of the two following options: | ||
|
@@ -60,51 +59,48 @@ Ensure you have a running instance of Chroma running. We recommend one of the tw | |
- If you are a fan of Kubernetes, you can use the Helm chart - https://github.com/amikos-tech/chromadb-chart (Note: You | ||
will need `Docker`, `minikube` and `kubectl` installed) | ||
|
||
Run tests: | ||
|
||
| **Important**: Since we are using the OpenAI API, you need to set the `OPENAI_API_KEY` environment variable. Simply | ||
create `.env` file in the root of the repository. | ||
|
||
```bash | ||
mvn test | ||
``` | ||
|
||
## Example | ||
### Example | ||
|
||
```java | ||
package tech.amikos; | ||
|
||
import com.google.gson.internal.LinkedTreeMap; | ||
import io.github.cdimascio.dotenv.Dotenv; | ||
import tech.amikos.chromadb.Client; | ||
import tech.amikos.chromadb.Collection; | ||
import tech.amikos.chromadb.EmbeddingFunction; | ||
import tech.amikos.chromadb.OpenAIEmbeddingFunction; | ||
import tech.amikos.chromadb.handler.ApiException; | ||
|
||
class TestApi { | ||
public void testQueryExample() throws ApiException { | ||
Client client = new Client("http://localhost:8000"); | ||
Dotenv dotenv = Dotenv.load(); | ||
String apiKey = dotenv.get("OPENAI_API_KEY"); | ||
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); | ||
Collection collection = client.createCollection("test-collection", null, true, ef); | ||
List<Map<String, String>> metadata = new ArrayList<>(); | ||
metadata.add(new HashMap<String, String>() {{ | ||
put("type", "scientist"); | ||
}}); | ||
metadata.add(new HashMap<String, String>() {{ | ||
put("type", "spy"); | ||
}}); | ||
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); | ||
LinkedTreeMap<String, Object> qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); | ||
System.out.println(qr); | ||
|
||
import java.util.*; | ||
|
||
public class Main { | ||
public static void main(String[] args) { | ||
try { | ||
Client client = new Client(System.getenv("CHROMA_URL")); | ||
String apiKey = System.getenv("OPENAI_API_KEY"); | ||
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey); | ||
Collection collection = client.createCollection("test-collection", null, true, ef); | ||
List<Map<String, String>> metadata = new ArrayList<>(); | ||
metadata.add(new HashMap<String, String>() {{ | ||
put("type", "scientist"); | ||
}}); | ||
metadata.add(new HashMap<String, String>() {{ | ||
put("type", "spy"); | ||
}}); | ||
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); | ||
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); | ||
System.out.println(qr); | ||
} catch (Exception e) { | ||
e.printStackTrace(); | ||
System.out.println(e); | ||
} | ||
} | ||
} | ||
``` | ||
|
||
The above should output: | ||
|
||
```bash | ||
{ids=[[2, 1]], distances=[[0.28461432651150426, 0.5096168232841949]], metadatas=[[{key=value}, {key=value}]], embeddings=null, documents=[[Hello, my name is Bond. I am a Spy., Hello, my name is John. I am a Data Scientist.]]} | ||
{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.28461432,0.50961685]]} | ||
``` | ||
|
||
## Development Notes | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ | |
<scm> | ||
<connection>scm:git:git://github.com/amikos-tech/chromadb-java-client.git</connection> | ||
<developerConnection>scm:git:ssh://[email protected]:amikos-tech/chromadb-java-client.git</developerConnection> | ||
<url>https://github.com/amikos-tech/chromadb-java-client/tree/master</url> | ||
<url>https://github.com/amikos-tech/chromadb-java-client/tree/main</url> | ||
</scm><licenses> | ||
<license> | ||
<name>MIT</name> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
src/main/java/tech/amikos/chromadb/HuggingFaceEmbeddingFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package tech.amikos.chromadb; | ||
|
||
|
||
import tech.amikos.hf.CreateEmbeddingRequest; | ||
import tech.amikos.hf.CreateEmbeddingResponse; | ||
import tech.amikos.hf.HuggingFaceClient; | ||
|
||
import java.util.List; | ||
|
||
public class HuggingFaceEmbeddingFunction implements EmbeddingFunction { | ||
|
||
private final String hfAPIKey; | ||
|
||
public HuggingFaceEmbeddingFunction(String hfAPIKey) { | ||
this.hfAPIKey = hfAPIKey; | ||
|
||
} | ||
|
||
@Override | ||
public List<List<Float>> createEmbedding(List<String> documents) { | ||
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey); | ||
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0]))); | ||
return response.getEmbeddings(); | ||
} | ||
|
||
@Override | ||
public List<List<Float>> createEmbedding(List<String> documents, String model) { | ||
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey); | ||
client.modelId(model); | ||
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0]))); | ||
return response.getEmbeddings(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package tech.amikos.hf; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.annotations.SerializedName; | ||
|
||
import java.util.HashMap; | ||
|
||
public class CreateEmbeddingRequest { | ||
|
||
@SerializedName("inputs") | ||
private String[] inputs; | ||
@SerializedName("options") | ||
private HashMap<String, Object> options; | ||
|
||
public CreateEmbeddingRequest inputs(String[] inputs) { | ||
this.inputs = inputs; | ||
return this; | ||
} | ||
|
||
public CreateEmbeddingRequest options(HashMap<String, Object> options) { | ||
this.options = options; | ||
return this; | ||
} | ||
|
||
public String[] getInputs() { | ||
return inputs; | ||
} | ||
|
||
public HashMap<String, Object> getOptions() { | ||
return options; | ||
} | ||
|
||
public String toString() { | ||
return this.json(); | ||
} | ||
|
||
public String json() { | ||
return new Gson().toJson(this); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package tech.amikos.hf; | ||
|
||
import com.google.gson.Gson; | ||
|
||
import java.util.List; | ||
|
||
public class CreateEmbeddingResponse { | ||
public List<List<Float>> embeddings; | ||
|
||
public List<List<Float>> getEmbeddings() { | ||
return embeddings; | ||
} | ||
|
||
public CreateEmbeddingResponse(List<List<Float>> embeddings) { | ||
this.embeddings = embeddings; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return new Gson().toJson(this); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package tech.amikos.hf; | ||
|
||
import com.google.gson.Gson; | ||
import okhttp3.*; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
//https://huggingface.co/blog/getting-started-with-embeddings | ||
public class HuggingFaceClient { | ||
|
||
private String baseUrl = "https://api-inference.huggingface.co/pipeline/feature-extraction/"; | ||
private String apiKey; | ||
|
||
private OkHttpClient client = new OkHttpClient(); | ||
|
||
private String modelId = "sentence-transformers/all-MiniLM-L6-v2"; | ||
private Gson gson = new Gson(); | ||
MediaType JSON = MediaType.parse("application/json; charset=utf-8"); | ||
|
||
public HuggingFaceClient(String apiKey) { | ||
this.apiKey = apiKey; | ||
} | ||
|
||
|
||
public HuggingFaceClient apiKey(String apiKey) { | ||
this.apiKey = apiKey; | ||
return this; | ||
} | ||
|
||
|
||
public HuggingFaceClient baseUrl(String baseUrl) { | ||
this.baseUrl = baseUrl; | ||
return this; | ||
} | ||
|
||
public HuggingFaceClient modelId(String modelId) { | ||
this.modelId = modelId; | ||
return this; | ||
} | ||
|
||
private String getApiKey() { | ||
if (this.apiKey == null) { | ||
throw new RuntimeException("API Key not set"); | ||
} | ||
return this.apiKey; | ||
} | ||
|
||
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) { | ||
Request request = new Request.Builder() | ||
.url(this.baseUrl + this.modelId) | ||
.post(RequestBody.create(req.json(), JSON)) | ||
.addHeader("Accept", "application/json") | ||
.addHeader("Content-Type", "application/json") | ||
.addHeader("Authorization", "Bearer " + getApiKey()) | ||
.build(); | ||
try (Response response = client.newCall(request).execute()) { | ||
if (!response.isSuccessful()) { | ||
throw new IOException("Unexpected code " + response); | ||
} | ||
|
||
String responseData = response.body().string(); | ||
|
||
List parsedResponse = gson.fromJson(responseData, List.class); | ||
|
||
return new CreateEmbeddingResponse(parsedResponse); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
return null; | ||
} | ||
|
||
} |
Oops, something went wrong.