Skip to content

Commit

Permalink
fix: reranking probabilities (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
giladgd authored Jan 7, 2025
1 parent 5d07289 commit d1b4416
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 40 deletions.
10 changes: 5 additions & 5 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
"typedoc": "^0.27.6",
"typedoc-plugin-markdown": "^4.4.1",
"typedoc-plugin-mdn-links": "^4.0.7",
"typedoc-vitepress-theme": "^1.1.1",
"typedoc-vitepress-theme": "^1.1.2",
"typescript": "^5.7.2",
"typescript-eslint": "^8.19.1",
"vite-node": "^2.1.8",
Expand Down
21 changes: 20 additions & 1 deletion src/evaluator/LlamaRankingContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ export class LlamaRankingContext {

/**
* Get the ranking score for a document for a query.
*
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
* @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
*/
public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
Expand All @@ -96,6 +99,9 @@ export class LlamaRankingContext {

/**
* Get the ranking scores for all the given documents for a query.
*
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
* @returns an array of ranking scores between 0 and 1 representing the probability that the document is relevant to the query.
*/
public async rankAll(query: Token[] | string | LlamaText, documents: Array<Token[] | string | LlamaText>): Promise<number[]> {
const resolvedTokens = documents.map((document) => this._getEvaluationInput(query, document));
Expand All @@ -120,9 +126,15 @@ export class LlamaRankingContext {

/**
* Get the ranking scores for all the given documents for a query and sort them by score from highest to lowest.
*
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
*/
public async rankAndSort<const T extends string>(query: Token[] | string | LlamaText, documents: T[]): Promise<Array<{
document: T,

/**
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
*/
score: number
}>> {
const scores = await this.rankAll(query, documents);
Expand Down Expand Up @@ -190,7 +202,10 @@ export class LlamaRankingContext {
if (embedding.length === 0)
return 0;

return embedding[0]!;
const logit = embedding[0]!;
const probability = logitToSigmoid(logit);

return probability;
});
}

Expand Down Expand Up @@ -249,3 +264,7 @@ function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffi

return undefined;
}

function logitToSigmoid(logit: number) {
return 1 / (1 + Math.exp(-logit));
}
78 changes: 45 additions & 33 deletions test/modelDependent/bgeReranker/rank.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ describe("bgeReranker", () => {
const highestRankDocument = documents[highestRankIndex];
expect(highestRankDocument).to.eql("Mount Everest is the tallest mountain in the world");

expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("-4");
expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("0.01798620996209156");
expect(simplifyRanks(ranks)).toMatchInlineSnapshot(`
[
-11,
-11,
-11,
-5.6,
-11,
-4,
-11,
-11,
-11,
-11,
0.00001670142184809518,
0.00001670142184809518,
0.00001670142184809518,
0.003684239899435989,
0.00001670142184809518,
0.01798620996209156,
0.00001670142184809518,
0.00001670142184809518,
0.00001670142184809518,
0.00001670142184809518,
]
`);
});
Expand Down Expand Up @@ -91,19 +91,19 @@ describe("bgeReranker", () => {
const highestRankDocument = documents[highestRankIndex];
expect(highestRankDocument).to.eql("Mount Everest is the tallest mountain in the world");

expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("-4");
expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("0.01798620996209156");
expect(simplifyRanks(ranks)).toMatchInlineSnapshot(`
[
-11,
-11,
-11,
-5.6,
-11,
-4,
-11,
-11,
-11,
-11,
0.00001670142184809518,
0.00001670142184809518,
0.00001670142184809518,
0.003684239899435989,
0.00001670142184809518,
0.01798620996209156,
0.00001670142184809518,
0.00001670142184809518,
0.00001670142184809518,
0.00001670142184809518,
]
`);
});
Expand Down Expand Up @@ -141,42 +141,42 @@ describe("bgeReranker", () => {
expect(simplifySortedRanks([topDocument])[0]).toMatchInlineSnapshot(`
{
"document": "Mount Everest is the tallest mountain in the world",
"score": -4,
"score": 0.01798620996209156,
}
`);
expect(simplifySortedRanks(rankedDocuments)).toMatchInlineSnapshot(`
[
{
"document": "Mount Everest is the tallest mountain in the world",
"score": -4,
"score": 0.01798620996209156,
},
{
"document": "The capital of France is Paris",
"score": -5.6,
"score": 0.003684239899435989,
},
{
"document": "Not all the things that shine are made of gold",
"score": -11,
"score": 0.00001670142184809518,
},
{
"document": "I love eating pizza with extra cheese",
"score": -11,
"score": 0.00001670142184809518,
},
{
"document": "Dogs love to play fetch with their owners",
"score": -11,
"score": 0.00001670142184809518,
},
{
"document": "The sky is clear and blue today",
"score": -11,
"score": 0.00001670142184809518,
},
{
"document": "Cleaning the house is a good way to keep it tidy",
"score": -11,
"score": 0.00001670142184809518,
},
{
"document": "A warm cup of tea is perfect for a cold winter day",
"score": -11,
"score": 0.00001670142184809518,
},
]
`);
Expand All @@ -185,16 +185,28 @@ describe("bgeReranker", () => {
});

function simplifyRanks<const T extends number[]>(ranks: T): T {
return ranks.map((rank) => parseFloat(roundToPrecision(rank, 0.2).toFixed(1))) as T;
return ranks.map((rank) => simplifyScore(rank)) as T;
}

function simplifySortedRanks<const T extends {document: string, score: number}[]>(values: T): T {
return values.map((item) => ({
document: item.document,
score: parseFloat(roundToPrecision(item.score, 0.2).toFixed(1))
score: simplifyScore(item.score)
})) as T;
}

function simplifyScore(score: number) {
return toSigmoid(parseFloat(roundToPrecision(toLogit(score), 0.2).toFixed(1)));
}

function roundToPrecision(value: number, precision: number): number {
return Math.round(value / precision) * precision;
}

function toLogit(sigmoid: number) {
return Math.log(sigmoid / (1 - sigmoid));
}

function toSigmoid(logit: number) {
return 1 / (1 + Math.exp(-logit));
}

0 comments on commit d1b4416

Please sign in to comment.