-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #256 from lgrammel/classify
Add classify function.
- Loading branch information
Showing
16 changed files
with
520 additions
and
209 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
--- | ||
sidebar_position: 75 | ||
--- | ||
|
||
# Classify Value | ||
|
||
Classifies a value into a category. | ||
|
||
## Usage | ||
|
||
[classify API](/api/modules#classify) | ||
|
||
You can call the `classify` function with a classifier model and a value to classify. | ||
|
||
#### Example | ||
|
||
```ts | ||
import { classify } from "modelfusion"; | ||
|
||
// strongly typed result: | ||
const result = await classify({ | ||
model: classifier, // see classifiers below | ||
value: "don't you love politics?", | ||
// ... other function options | ||
}); | ||
|
||
switch (result) { | ||
case "politics": | ||
console.log("politics"); | ||
break; | ||
case "chitchat": | ||
console.log("chitchat"); | ||
break; | ||
case null: | ||
console.log("null"); | ||
break; | ||
} | ||
``` | ||
|
||
## Classifiers | ||
|
||
### EmbeddingSimilarityClassifier | ||
|
||
[EmbeddingSimilarityClassifier API](/api/classes/EmbeddingSimilarityClassifier) | ||
|
||
Classifies values based on their distance to the values from a set of clusters. | ||
When the distance is below a certain threshold, the value is classified as belonging to the cluster, | ||
and the cluster name is returned. Otherwise, the value is classified as null. | ||
|
||
#### Example | ||
|
||
```ts | ||
import { EmbeddingSimilarityClassifier, openai } from "modelfusion"; | ||
|
||
const classifier = new EmbeddingSimilarityClassifier({ | ||
// you can use any supported embedding model: | ||
embeddingModel: openai.TextEmbedder({ | ||
model: "text-embedding-ada-002", | ||
}), | ||
|
||
// the threshold for the distance between the value and the cluster values: | ||
similarityThreshold: 0.82, | ||
|
||
clusters: [ | ||
{ | ||
name: "politics" as const, | ||
values: [ | ||
"isn't politics the best thing ever", | ||
"why don't you tell me about your political opinions", | ||
"don't you just love the president", | ||
"don't you just hate the president", | ||
"they're going to destroy this country!", | ||
"they will save the country!", | ||
], | ||
}, | ||
{ | ||
name: "chitchat" as const, | ||
values: [ | ||
"how's the weather today?", | ||
"how are things going?", | ||
"lovely weather today", | ||
"the weather is horrendous", | ||
"let's go to the chippy", | ||
], | ||
}, | ||
], | ||
}); | ||
``` |
44 changes: 0 additions & 44 deletions
44
examples/basic/src/classifier/semantic-classifier-example.ts
This file was deleted.
Oops, something went wrong.
53 changes: 0 additions & 53 deletions
53
examples/basic/src/classifier/semantic-classifier-switch-example.ts
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import dotenv from "dotenv"; | ||
import { EmbeddingSimilarityClassifier, classify, openai } from "modelfusion"; | ||
|
||
dotenv.config(); | ||
|
||
const classifier = new EmbeddingSimilarityClassifier({ | ||
embeddingModel: openai.TextEmbedder({ | ||
model: "text-embedding-ada-002", | ||
}), | ||
similarityThreshold: 0.82, | ||
clusters: [ | ||
{ | ||
name: "politics" as const, | ||
values: [ | ||
"isn't politics the best thing ever", | ||
"why don't you tell me about your political opinions", | ||
"don't you just love the president", | ||
"don't you just hate the president", | ||
"they're going to destroy this country!", | ||
"they will save the country!", | ||
], | ||
}, | ||
{ | ||
name: "chitchat" as const, | ||
values: [ | ||
"how's the weather today?", | ||
"how are things going?", | ||
"lovely weather today", | ||
"the weather is horrendous", | ||
"let's go to the chippy", | ||
], | ||
}, | ||
], | ||
}); | ||
|
||
async function main() { | ||
// politics: | ||
console.log( | ||
await classify({ | ||
model: classifier, | ||
value: "don't you love politics?", | ||
}) | ||
); | ||
|
||
// chitchat: | ||
console.log( | ||
await classify({ | ||
model: classifier, | ||
value: "how's the weather today?", | ||
}) | ||
); | ||
|
||
// null (no match): | ||
console.log( | ||
await classify({ | ||
model: classifier, | ||
value: "I'm interested in learning about llama 2", | ||
}) | ||
); | ||
} | ||
|
||
main().catch(console.error); |
57 changes: 57 additions & 0 deletions
57
examples/basic/src/model-function/classify-switch-example.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import dotenv from "dotenv"; | ||
import { EmbeddingSimilarityClassifier, classify, openai } from "modelfusion"; | ||
|
||
dotenv.config(); | ||
|
||
const classifier = new EmbeddingSimilarityClassifier({ | ||
embeddingModel: openai.TextEmbedder({ | ||
model: "text-embedding-ada-002", | ||
}), | ||
similarityThreshold: 0.82, | ||
clusters: [ | ||
{ | ||
name: "politics" as const, | ||
values: [ | ||
"isn't politics the best thing ever", | ||
"why don't you tell me about your political opinions", | ||
"don't you just love the president", | ||
"don't you just hate the president", | ||
"they're going to destroy this country!", | ||
"they will save the country!", | ||
], | ||
}, | ||
{ | ||
name: "chitchat" as const, | ||
values: [ | ||
"how's the weather today?", | ||
"how are things going?", | ||
"lovely weather today", | ||
"the weather is horrendous", | ||
"let's go to the chippy", | ||
], | ||
}, | ||
], | ||
}); | ||
|
||
async function main() { | ||
// strongly typed result: | ||
const result = await classify({ | ||
model: classifier, | ||
value: "don't you love politics?", | ||
// logging: "basic-text", | ||
}); | ||
|
||
switch (result) { | ||
case "politics": | ||
console.log("politics"); | ||
break; | ||
case "chitchat": | ||
console.log("chitchat"); | ||
break; | ||
case null: | ||
console.log("null"); | ||
break; | ||
} | ||
} | ||
|
||
main().catch(console.error); |
Oops, something went wrong.