diff --git a/mindsearch/agent/__init__.py b/mindsearch/agent/__init__.py index 31c434a..72b6136 100644 --- a/mindsearch/agent/__init__.py +++ b/mindsearch/agent/__init__.py @@ -25,7 +25,9 @@ LLM = {} -def init_agent(lang="cn", model_format="internlm_server", search_engine="BingSearch", use_async=False): +def init_agent( + lang="cn", model_format="internlm_server", search_engine="BingSearch", use_async=False, identity_cognition='' +): mode = "async" if use_async else "sync" llm = LLM.get(model_format, {}).get(mode) if llm is None: @@ -59,7 +61,7 @@ def init_agent(lang="cn", model_format="internlm_server", search_engine="BingSea ] agent = (AsyncMindSearchAgent if use_async else MindSearchAgent)( llm=llm, - template=date, + template=identity_cognition + date, output_format=InterpreterParser( begin="<|action_start|><|interpreter|>", end="<|action_end|>", @@ -68,7 +70,7 @@ def init_agent(lang="cn", model_format="internlm_server", search_engine="BingSea searcher_cfg=dict( llm=llm, plugins=plugins, - template=date, + template=identity_cognition + date, output_format=PluginParser( begin="<|action_start|><|plugin|>", end="<|action_end|>", diff --git a/mindsearch/app.py b/mindsearch/app.py index 1f71cc6..32483dc 100644 --- a/mindsearch/app.py +++ b/mindsearch/app.py @@ -24,6 +24,12 @@ def parse_arguments(): parser.add_argument("--model_format", default="internlm_server", type=str, help="Model format") parser.add_argument("--search_engine", default="BingSearch", type=str, help="Search engine") parser.add_argument("--asy", default=False, action="store_true", help="Agent mode") + parser.add_argument( + "--identity_cognition", + default="You are InternLM (书生·浦语), a helpful, honest, and harmless AI assistant developed by Shanghai AI Laboratory (上海人工智能实验室).\n", + type=str, + help="Identity cognition", + ) return parser.parse_args() @@ -62,9 +68,7 @@ def _postprocess_agent_message(message: dict) -> dict: node_fmt = node["response"]["formatted"] if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt: node["response"]["content"] = None - node_fmt["thought"] = ( - node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0] - ) + node_fmt["thought"] = node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0] if isinstance(node_fmt["action"], str): node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0] else: @@ -116,9 +120,7 @@ async def async_generator_wrapper(): except Exception as exc: msg = "An error occurred while generating the response." logging.exception(msg) - response_json = json.dumps( - dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False - ) + response_json = json.dumps(dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False) yield {"data": response_json} finally: await stop_event.wait() # Waiting for async_generator_wrapper to stop @@ -132,6 +134,7 @@ async def async_generator_wrapper(): lang=args.lang, model_format=args.model_format, search_engine=args.search_engine, + identity_cognition=args.identity_cognition, ) return EventSourceResponse(generate(), ping=300) @@ -150,9 +153,7 @@ async def generate(): except Exception as exc: msg = "An error occurred while generating the response." logging.exception(msg) - response_json = json.dumps( - dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False - ) + response_json = json.dumps(dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False) yield {"data": response_json} finally: agent.agent.memory.memory_map.pop(session_id, None) @@ -164,6 +165,7 @@ async def generate(): model_format=args.model_format, search_engine=args.search_engine, use_async=True, + identity_cognition=args.identity_cognition, ) return EventSourceResponse(generate(), ping=300)