Skip to content

Commit

Permalink
Merge pull request #33 from Barqawiz/aws-llama-v2-integration
Browse files Browse the repository at this point in the history
Aws llama v2 integration
  • Loading branch information
Barqawiz authored Jul 24, 2023
2 parents 75e94ec + c6097f1 commit 183f40e
Show file tree
Hide file tree
Showing 13 changed files with 447 additions and 127 deletions.
250 changes: 144 additions & 106 deletions IntelliNode/function/Chatbot.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,135 +7,173 @@ Copyright 2023 Github.com/Barqawiz/IntelliNode
*/
const OpenAIWrapper = require("../wrappers/OpenAIWrapper");
const ReplicateWrapper = require('../wrappers/ReplicateWrapper');
const { ChatGPTInput, ChatModelInput, ChatGPTMessage, ChatLLamaInput } = require("../model/input/ChatModelInput");
const AWSEndpointWrapper = require('../wrappers/AWSEndpointWrapper');
const {
ChatGPTInput,
ChatModelInput,
ChatGPTMessage,
ChatLLamaInput,
LLamaReplicateInput,
LLamaSageInput
} = require("../model/input/ChatModelInput");

// Chat providers supported by the Chatbot class.
// Values are the identifiers callers pass as the `provider` constructor argument.
const SupportedChatModels = {
  OPENAI: "openai",
  REPLICATE: "replicate",
  SAGEMAKER: "sagemaker"
};

class Chatbot {
constructor(keyValue, provider = SupportedChatModels.OPENAI, customProxyHelper=null) {
const supportedModels = this.getSupportedModels();

if (supportedModels.includes(provider)) {
this.initiate(keyValue, provider, customProxyHelper);
} else {
const models = supportedModels.join(" - ");
throw new Error(
`The received keyValue is not supported. Send any model from: ${models}`
);
constructor(keyValue, provider = SupportedChatModels.OPENAI, customProxyHelper = null) {
const supportedModels = this.getSupportedModels();

if (supportedModels.includes(provider)) {
this.initiate(keyValue, provider, customProxyHelper);
} else {
const models = supportedModels.join(" - ");
throw new Error(
`The received keyValue is not supported. Send any model from: ${models}`
);
}
}
}

initiate(keyValue, provider, customProxyHelper=null) {
this.provider = provider;
initiate(keyValue, provider, customProxyHelper = null) {
this.provider = provider;

if (provider === SupportedChatModels.OPENAI) {
this.openaiWrapper = new OpenAIWrapper(keyValue, customProxyHelper);
} else if (provider === SupportedChatModels.REPLICATE) {
this.replicateWrapper = new ReplicateWrapper(keyValue);
} else {
throw new Error("Invalid provider name");
}
}

getSupportedModels() {
return Object.values(SupportedChatModels);
}

async chat(modelInput, functions = null, function_call = null, debugMode = true) {
if (this.provider === SupportedChatModels.OPENAI) {
return this._chatGPT(modelInput, functions, function_call);
} else if (this.provider === SupportedChatModels.REPLICATE) {
// functions not supported for REPLICATE models
if(functions!=null || function_call!=null){
throw new Error('The functions and function_call are supported for chatGPT models only. They should be null for LLama model.');
if (provider === SupportedChatModels.OPENAI) {
this.openaiWrapper = new OpenAIWrapper(keyValue, customProxyHelper);
} else if (provider === SupportedChatModels.REPLICATE) {
this.replicateWrapper = new ReplicateWrapper(keyValue);
} else if (provider === SupportedChatModels.SAGEMAKER) {
this.sagemakerWrapper = new AWSEndpointWrapper(customProxyHelper.url, keyValue);
} else {
throw new Error("Invalid provider name");
}
return this._chatReplicateLLama(modelInput, debugMode);
} else {
throw new Error("The provider is not supported");
}
}

async _chatGPT(modelInput, functions = null, function_call = null) {
let params;
getSupportedModels() {
return Object.values(SupportedChatModels);
}

if (modelInput instanceof ChatModelInput) {
params = modelInput.getChatInput();
async chat(modelInput, functions = null, function_call = null, debugMode = true) {
if (this.provider === SupportedChatModels.OPENAI) {
return this._chatGPT(modelInput, functions, function_call);
} else if (this.provider === SupportedChatModels.REPLICATE) {
// functions not supported for REPLICATE models
if (functions != null || function_call != null) {
throw new Error('The functions and function_call are supported for chatGPT models only. They should be null for LLama model.');
}

} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatGPTInput or a dictionary");
}
return this._chatReplicateLLama(modelInput, debugMode);
} else if (this.provider === SupportedChatModels.SAGEMAKER) {

const results = await this.openaiWrapper.generateChatText(params, functions, function_call);
return results.choices.map((choice) => {
if (choice.finish_reason === 'function_call' && choice.message.function_call) {
return {
content: choice.message.content,
function_call: choice.message.function_call
};
// functions not supported for REPLICATE models
if (functions != null || function_call != null) {
throw new Error('The functions and function_call are supported for chatGPT models only. They should be null for LLama model.');
}

return this._chatSageMaker(modelInput);
} else {
return choice.message.content;
throw new Error("The provider is not supported");
}
});
}

async _chatReplicateLLama(modelInput, debugMode) {
let params;
const waitTime = 1000, maxIterate = 100;
let iteration = 0;

console.log('call')
if (modelInput instanceof ChatLLamaInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatLLamaInput or a dictionary");
}

try {
const modelName = params.model;
const inputData = params.inputData;

const prediction = await this.replicateWrapper.predict(modelName, inputData);

return new Promise((resolve, reject) => {
const poll = setInterval(async () => {
const status = await this.replicateWrapper.getPredictionStatus(prediction.id);
if (debugMode) {
console.log('The current status:', status.status);
}
}

if (status.status === 'succeeded' || status.status === 'failed') {
// stop the loop if prediction has completed or failed
clearInterval(poll);
async _chatGPT(modelInput, functions = null, function_call = null) {
let params;

if (status.status === 'succeeded') {
resolve(status.output.join(' '));
} else {
console.error('LLama prediction failed:', status.error);
reject(new Error('LLama prediction failed.'));
}
}
if (iteration > maxIterate) {
reject(new Error('Replicate taking too long to process the input, try again later!'));
if (modelInput instanceof ChatModelInput) {
params = modelInput.getChatInput();

} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatGPTInput or a dictionary");
}

const results = await this.openaiWrapper.generateChatText(params, functions, function_call);
return results.choices.map((choice) => {
if (choice.finish_reason === 'function_call' && choice.message.function_call) {
return {
content: choice.message.content,
function_call: choice.message.function_call
};
} else {
return choice.message.content;
}
iteration += 1
}, waitTime);
});
} catch (error) {
console.error('LLama Error:', error);
throw error;
}
}

async _chatReplicateLLama(modelInput, debugMode) {
let params;
const waitTime = 1000,
maxIterate = 100;
let iteration = 0;

console.log('call')
if (modelInput instanceof ChatModelInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatLLamaInput or a dictionary");
}

try {
const modelName = params.model;
const inputData = params.inputData;

const prediction = await this.replicateWrapper.predict(modelName, inputData);

return new Promise((resolve, reject) => {
const poll = setInterval(async () => {
const status = await this.replicateWrapper.getPredictionStatus(prediction.id);
if (debugMode) {
console.log('The current status:', status.status);
}

if (status.status === 'succeeded' || status.status === 'failed') {
// stop the loop if prediction has completed or failed
clearInterval(poll);

if (status.status === 'succeeded') {
resolve([status.output.join(' ')]);
} else {
console.error('LLama prediction failed:', status.error);
reject(new Error('LLama prediction failed.'));
}
}
if (iteration > maxIterate) {
reject(new Error('Replicate taking too long to process the input, try again later!'));
}
iteration += 1
}, waitTime);
});
} catch (error) {
console.error('LLama Error:', error);
throw error;
}
}

async _chatSageMaker(modelInput) {

let params;

if (modelInput instanceof LLamaSageInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of LLamaSageInput or a dictionary");
}

const results = await this.sagemakerWrapper.predict(params);

return results.map(result => result.generation ? result.generation.content : result);
}

} /*chatbot class*/

module.exports = {
Chatbot,
SupportedChatModels,
Chatbot,
SupportedChatModels,
};
8 changes: 6 additions & 2 deletions IntelliNode/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const { Gen } = require('./function/Gen');
const LanguageModelInput = require('./model/input/LanguageModelInput');
const ImageModelInput = require('./model/input/ImageModelInput');
const Text2SpeechInput = require('./model/input/Text2SpeechInput');
const { ChatGPTInput, ChatLLamaInput, ChatGPTMessage } = require("./model/input/ChatModelInput");
const { ChatGPTInput, ChatLLamaInput, LLamaReplicateInput, ChatGPTMessage, LLamaSageInput } = require("./model/input/ChatModelInput");
const FunctionModelInput = require("./model/input/FunctionModelInput");
const EmbedInput = require('./model/input/EmbedInput');
// wrappers
Expand All @@ -23,6 +23,7 @@ const OpenAIWrapper = require('./wrappers/OpenAIWrapper');
const StabilityAIWrapper = require('./wrappers/StabilityAIWrapper');
const HuggingWrapper = require('./wrappers/HuggingWrapper');
const ReplicateWrapper = require('./wrappers/ReplicateWrapper');
const AWSEndpointWrapper = require('./wrappers/AWSEndpointWrapper');
// utils
const AudioHelper = require('./utils/AudioHelper');
const Config2 = require('./utils/Config2');
Expand Down Expand Up @@ -52,6 +53,7 @@ module.exports = {
SupportedChatModels,
ChatGPTInput,
ChatLLamaInput,
LLamaReplicateInput,
ChatGPTMessage,
EmbedInput,
MatchHelpers,
Expand All @@ -64,5 +66,7 @@ module.exports = {
ReplicateWrapper,
Gen,
ProxyHelper,
FunctionModelInput
FunctionModelInput,
AWSEndpointWrapper,
LLamaSageInput
};
Loading

0 comments on commit 183f40e

Please sign in to comment.