diff --git a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts index 5b43ffc1369..0a4de643df1 100644 --- a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts +++ b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts @@ -1,6 +1,7 @@ import { BaseRetriever } from '@langchain/core/retrievers' import { BaseLanguageModel } from '@langchain/core/language_models/base' import { RetrievalQAChain } from 'langchain/chains' +import { BasePromptTemplate } from '@langchain/core/prompts' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' @@ -38,6 +39,18 @@ class RetrievalQAChain_Chains implements INode { name: 'vectorStoreRetriever', type: 'BaseRetriever' }, + { + label: 'Prompt', + name: 'prompt', + type: 'BasePromptTemplate', + optional: true + }, + { + label: 'Return Source Documents', + name: 'returnSourceDocuments', + type: 'boolean', + optional: true + }, { label: 'Input Moderation', description: 'Detect text that could generate harmful output and prevent it from being sent to the language model', @@ -51,9 +64,15 @@ class RetrievalQAChain_Chains implements INode { async init(nodeData: INodeData): Promise { const model = nodeData.inputs?.model as BaseLanguageModel + const prompt = nodeData.inputs?.prompt as BasePromptTemplate | undefined const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever + const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean - const chain = RetrievalQAChain.fromLLM(model, vectorStoreRetriever, { verbose: process.env.DEBUG === 'true' ? true : false }) + const chain = RetrievalQAChain.fromLLM(model, vectorStoreRetriever, { + prompt, + returnSourceDocuments, + verbose: process.env.DEBUG === 'true' ? true : false + }) return chain } @@ -86,9 +105,11 @@ class RetrievalQAChain_Chains implements INode { if (shouldStreamResponse) { const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.call(obj, [loggerHandler, handler, ...callbacks]) + if (res.text && res.sourceDocuments) return res return res?.text } else { const res = await chain.call(obj, [loggerHandler, ...callbacks]) + if (res.text && res.sourceDocuments) return res return res?.text } }