Skip to content

Commit

Permalink
Merge pull request #256 from lgrammel/classify
Browse files Browse the repository at this point in the history
Add classify function.
  • Loading branch information
lgrammel authored Jan 13, 2024
2 parents f1a00fd + babafb4 commit a001d31
Show file tree
Hide file tree
Showing 16 changed files with 520 additions and 209 deletions.
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ Providers: [OpenAI (Whisper)](https://modelfusion.dev/integration/model-provider
Create embeddings for text and other values. Embeddings are vectors that represent the essence of the values in the context of the model.

```ts
import { embed, embedMany, openai } from "modelfusion";

// embed single value:
const embedding = await embed({
model: openai.TextEmbedder({ model: "text-embedding-ada-002" }),
Expand All @@ -309,6 +311,43 @@ const embeddings = await embedMany({

Providers: [OpenAI](https://modelfusion.dev/integration/model-provider/openai), [Llama.cpp](https://modelfusion.dev/integration/model-provider/llamacpp), [Ollama](https://modelfusion.dev/integration/model-provider/ollama), [Mistral](https://modelfusion.dev/integration/model-provider/mistral), [Hugging Face](https://modelfusion.dev/integration/model-provider/huggingface), [Cohere](https://modelfusion.dev/integration/model-provider/cohere)

### [Classify Value](https://modelfusion.dev/guide/function/classify)

Classifies a value into a category.

```ts
import { classify, EmbeddingSimilarityClassifier, openai } from "modelfusion";

const classifier = new EmbeddingSimilarityClassifier({
embeddingModel: openai.TextEmbedder({ model: "text-embedding-ada-002" }),
similarityThreshold: 0.82,
clusters: [
{
name: "politics" as const,
values: [
"they will save the country!",
// ...
],
},
{
name: "chitchat" as const,
values: [
"how's the weather today?",
// ...
],
},
],
});

// strongly typed result:
const result = await classify({
model: classifier,
value: "don't you love politics?",
});
```

Classifiers: [EmbeddingSimilarityClassifier](https://modelfusion.dev/guide/function/classify#embeddingsimilarityclassifier)

### [Tokenize Text](https://modelfusion.dev/guide/function/tokenize-text)

Split text into tokens and reconstruct the text from tokens.
Expand Down Expand Up @@ -552,6 +591,7 @@ modelfusion.setLogFormat("detailed-object"); // log full events
- [Generate transcription](https://modelfusion.dev/guide/function/generation-transcription)
- [Tokenize Text](https://modelfusion.dev/guide/function/tokenize-text)
- [Embed Value](https://modelfusion.dev/guide/function/embed)
- [Classify Value](https://modelfusion.dev/guide/function/classify)
- [Tools](https://modelfusion.dev/guide/tools)
- [Use Tool](https://modelfusion.dev/guide/tools/use-tool)
- [Use Tools](https://modelfusion.dev/guide/tools/use-tools)
Expand Down
88 changes: 88 additions & 0 deletions docs/guide/function/classify.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
---
sidebar_position: 75
---

# Classify Value

Classifies a value into a category.

## Usage

[classify API](/api/modules#classify)

You can call the `classify` function with a classifier model and a value to classify.

#### Example

```ts
import { classify } from "modelfusion";

// strongly typed result:
const result = await classify({
model: classifier, // see classifiers below
value: "don't you love politics?",
// ... other function options
});

switch (result) {
case "politics":
console.log("politics");
break;
case "chitchat":
console.log("chitchat");
break;
case null:
console.log("null");
break;
}
```

## Classifiers

### EmbeddingSimilarityClassifier

[EmbeddingSimilarityClassifier API](/api/classes/EmbeddingSimilarityClassifier)

Classifies values based on their distance to the values from a set of clusters.
When the distance is below a certain threshold, the value is classified as belonging to the cluster,
and the cluster name is returned. Otherwise, the value is classified as null.

#### Example

```ts
import { EmbeddingSimilarityClassifier, openai } from "modelfusion";

const classifier = new EmbeddingSimilarityClassifier({
// you can use any supported embedding model:
embeddingModel: openai.TextEmbedder({
model: "text-embedding-ada-002",
}),

// the threshold for the distance between the value and the cluster values:
similarityThreshold: 0.82,

clusters: [
{
name: "politics" as const,
values: [
"isn't politics the best thing ever",
"why don't you tell me about your political opinions",
"don't you just love the president",
"don't you just hate the president",
"they're going to destroy this country!",
"they will save the country!",
],
},
{
name: "chitchat" as const,
values: [
"how's the weather today?",
"how are things going?",
"lovely weather today",
"the weather is horrendous",
"let's go to the chippy",
],
},
],
});
```
44 changes: 0 additions & 44 deletions examples/basic/src/classifier/semantic-classifier-example.ts

This file was deleted.

This file was deleted.

62 changes: 62 additions & 0 deletions examples/basic/src/model-function/classify-example.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import dotenv from "dotenv";
import { EmbeddingSimilarityClassifier, classify, openai } from "modelfusion";

dotenv.config();

const classifier = new EmbeddingSimilarityClassifier({
embeddingModel: openai.TextEmbedder({
model: "text-embedding-ada-002",
}),
similarityThreshold: 0.82,
clusters: [
{
name: "politics" as const,
values: [
"isn't politics the best thing ever",
"why don't you tell me about your political opinions",
"don't you just love the president",
"don't you just hate the president",
"they're going to destroy this country!",
"they will save the country!",
],
},
{
name: "chitchat" as const,
values: [
"how's the weather today?",
"how are things going?",
"lovely weather today",
"the weather is horrendous",
"let's go to the chippy",
],
},
],
});

async function main() {
// politics:
console.log(
await classify({
model: classifier,
value: "don't you love politics?",
})
);

// chitchat:
console.log(
await classify({
model: classifier,
value: "how's the weather today?",
})
);

// null (no match):
console.log(
await classify({
model: classifier,
value: "I'm interested in learning about llama 2",
})
);
}

main().catch(console.error);
57 changes: 57 additions & 0 deletions examples/basic/src/model-function/classify-switch-example.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import dotenv from "dotenv";
import { EmbeddingSimilarityClassifier, classify, openai } from "modelfusion";

dotenv.config();

const classifier = new EmbeddingSimilarityClassifier({
embeddingModel: openai.TextEmbedder({
model: "text-embedding-ada-002",
}),
similarityThreshold: 0.82,
clusters: [
{
name: "politics" as const,
values: [
"isn't politics the best thing ever",
"why don't you tell me about your political opinions",
"don't you just love the president",
"don't you just hate the president",
"they're going to destroy this country!",
"they will save the country!",
],
},
{
name: "chitchat" as const,
values: [
"how's the weather today?",
"how are things going?",
"lovely weather today",
"the weather is horrendous",
"let's go to the chippy",
],
},
],
});

async function main() {
// strongly typed result:
const result = await classify({
model: classifier,
value: "don't you love politics?",
// logging: "basic-text",
});

switch (result) {
case "politics":
console.log("politics");
break;
case "chitchat":
console.log("chitchat");
break;
case null:
console.log("null");
break;
}
}

main().catch(console.error);
Loading

0 comments on commit a001d31

Please sign in to comment.