Skip to content

Commit

Permalink
Fix: Flash 2.0 doesn’t respect BLOCK_NONE on ALL
Browse files Browse the repository at this point in the history
harm categories
  • Loading branch information
YuenSzeHong committed Jan 9, 2025
1 parent 32398fe commit 3369e3b
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions src/api_proxy/worker.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import { Buffer } from "node:buffer";

export default {
async fetch (request) {
async fetch(request) {
if (request.method === "OPTIONS") {
return handleOPTIONS();
}
Expand Down Expand Up @@ -79,7 +79,7 @@ const makeHeaders = (apiKey, more) => ({
...more
});

async function handleModels (apiKey) {
async function handleModels(apiKey) {
const response = await fetch(`${BASE_URL}/${API_VERSION}/models`, {
headers: makeHeaders(apiKey),
});
Expand All @@ -100,12 +100,12 @@ async function handleModels (apiKey) {
}

const DEFAULT_EMBEDDINGS_MODEL = "text-embedding-004";
async function handleEmbeddings (req, apiKey) {
async function handleEmbeddings(req, apiKey) {
if (typeof req.model !== "string") {
throw new HttpError("model is not specified", 400);
}
if (!Array.isArray(req.input)) {
req.input = [ req.input ];
req.input = [req.input];
}
let model;
if (req.model.startsWith("models/")) {
Expand Down Expand Up @@ -142,9 +142,9 @@ async function handleEmbeddings (req, apiKey) {
}

const DEFAULT_MODEL = "gemini-1.5-pro-latest";
async function handleCompletions (req, apiKey) {
async function handleCompletions(req, apiKey) {
let model = DEFAULT_MODEL;
switch(true) {
switch (true) {
case typeof req.model !== "string":
break;
case req.model.startsWith("models/"):
Expand Down Expand Up @@ -196,10 +196,15 @@ const harmCategory = [
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
];
const safetySettings = harmCategory.map(category => ({
category,
threshold: "BLOCK_NONE",
}));

// Build per-request safety settings that disable blocking for every harm
// category. Gemini 2.0 models reject the "BLOCK_NONE" threshold and require
// the newer "OFF" value instead, so the threshold is chosen from the model
// name (this is the fix this commit is about).
const safetySettings = (model) => {
  // Bug fix: the body previously read `modelName`, an undefined identifier,
  // which threw a ReferenceError on every request. Optional chaining keeps
  // the pre-2.0 default when `model` is null/undefined.
  const threshold = model?.includes('2.0') ? 'OFF' : 'BLOCK_NONE';
  return harmCategory.map(category => ({
    category,
    threshold,
  }));
};

const fieldsMap = {
stop: "stopSequences",
n: "candidateCount", // not for streaming
Expand All @@ -221,14 +226,14 @@ const transformConfig = (req) => {
}
}
if (req.response_format) {
switch(req.response_format.type) {
switch (req.response_format.type) {
case "json_schema":
cfg.responseSchema = req.response_format.json_schema?.schema;
if (cfg.responseSchema && "enum" in cfg.responseSchema) {
cfg.responseMimeType = "text/x.enum";
break;
}
// eslint-disable-next-line no-fallthrough
// eslint-disable-next-line no-fallthrough
case "json_object":
cfg.responseMimeType = "application/json";
break;
Expand Down Expand Up @@ -330,7 +335,7 @@ const transformMessages = async (messages) => {

// Translate an OpenAI-style chat request into a Gemini generateContent body:
// messages, per-model safety settings, and generation config.
const transformRequest = async (req) => ({
  ...await transformMessages(req.messages),
  // Safety settings now depend on the requested model (2.0 needs "OFF");
  // the stale shorthand `safetySettings,` duplicate key is removed.
  safetySettings: safetySettings(req.model),
  generationConfig: transformConfig(req),
});

Expand All @@ -354,7 +359,8 @@ const transformCandidates = (key, cand) => ({
index: cand.index || 0, // 0-index is absent in new -002 models response
[key]: {
role: "assistant",
content: cand.content?.parts.map(p => p.text).join(SEP) },
content: cand.content?.parts.map(p => p.text).join(SEP)
},
logprobs: null,
finish_reason: reasonsMap[cand.finishReason] || cand.finishReason,
});
Expand All @@ -371,7 +377,7 @@ const processCompletionsResponse = (data, model, id) => {
return JSON.stringify({
id,
choices: data.candidates.map(transformCandidatesMessage),
created: Math.floor(Date.now()/1000),
created: Math.floor(Date.now() / 1000),
model,
//system_fingerprint: "fp_69829325d0",
object: "chat.completion",
Expand All @@ -380,7 +386,7 @@ const processCompletionsResponse = (data, model, id) => {
};

const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
async function parseStream (chunk, controller) {
async function parseStream(chunk, controller) {
chunk = await chunk;
if (!chunk) { return; }
this.buffer += chunk;
Expand All @@ -391,21 +397,21 @@ async function parseStream (chunk, controller) {
this.buffer = this.buffer.substring(match[0].length);
} while (true); // eslint-disable-line no-constant-condition
}
async function parseStreamFlush (controller) {
// Flush handler for the SSE line parser: if any unparsed text is still
// sitting in the buffer when the upstream closes, log it as invalid and
// forward it downstream unchanged rather than dropping it.
async function parseStreamFlush(controller) {
  const leftover = this.buffer;
  if (!leftover) { return; }
  console.error("Invalid data:", leftover);
  controller.enqueue(leftover);
}

function transformResponseStream (data, stop, first) {
function transformResponseStream(data, stop, first) {
const item = transformCandidatesDelta(data.candidates[0]);
if (stop) { item.delta = {}; } else { item.finish_reason = null; }
if (first) { item.delta.content = ""; } else { delete item.delta.role; }
const output = {
id: this.id,
choices: [item],
created: Math.floor(Date.now()/1000),
created: Math.floor(Date.now() / 1000),
model: this.model,
//system_fingerprint: "fp_69829325d0",
object: "chat.completion.chunk",
Expand All @@ -416,7 +422,7 @@ function transformResponseStream (data, stop, first) {
return "data: " + JSON.stringify(output) + delimiter;
}
const delimiter = "\n\n";
async function toOpenAiStream (chunk, controller) {
async function toOpenAiStream(chunk, controller) {
const transform = transformResponseStream.bind(this);
const line = await chunk;
if (!line) { return; }
Expand Down Expand Up @@ -445,7 +451,7 @@ async function toOpenAiStream (chunk, controller) {
controller.enqueue(transform(data));
}
}
async function toOpenAiStreamFlush (controller) {
async function toOpenAiStreamFlush(controller) {
const transform = transformResponseStream.bind(this);
if (this.last.length > 0) {
for (const data of this.last) {
Expand Down

0 comments on commit 3369e3b

Please sign in to comment.