Suppress the pydantic.ValidationError in SearchAndSummarize

This commit is contained in:
shenchucheng 2023-08-17 21:26:11 +08:00
parent 4da1a7d0ac
commit d2aaafd235

View file

@ -5,6 +5,8 @@
@Author : alexanderwu
@File : search_google.py
"""
import pydantic
from metagpt.actions import Action
from metagpt.config import Config
from metagpt.logs import logger
@ -34,7 +36,7 @@ A: MLOps competitors
8. Dataiku
"""
SEARCH_AND_SUMMARIZE_SYSTEM_EN_US = SEARCH_AND_SUMMARIZE_SYSTEM.format(LANG='en-us')
SEARCH_AND_SUMMARIZE_SYSTEM_EN_US = SEARCH_AND_SUMMARIZE_SYSTEM.format(LANG="en-us")
SEARCH_AND_SUMMARIZE_PROMPT = """
### Reference Information
@ -102,25 +104,26 @@ class SearchAndSummarize(Action):
def __init__(self, name="", context=None, llm=None, engine=None, search_func=None):
self.config = Config()
self.engine = engine or self.config.search_engine
self.search_engine = SearchEngine(self.engine, run_func=search_func)
try:
self.search_engine = SearchEngine(self.engine, run_func=search_func)
except pydantic.ValidationError:
self.search_engine = None
self.result = ""
super().__init__(name, context, llm)
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
no_serpapi = not self.config.serpapi_api_key or 'YOUR_API_KEY' == self.config.serpapi_api_key
no_serper = not self.config.serper_api_key or 'YOUR_API_KEY' == self.config.serper_api_key
no_google = not self.config.google_api_key or 'YOUR_API_KEY' == self.config.google_api_key
if no_serpapi and no_google and no_serper:
logger.warning('Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature')
if self.search_engine is None:
logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature")
return ""
query = context[-1].content
# logger.debug(query)
rsp = await self.search_engine.run(query)
self.result = rsp
if not rsp:
logger.error('empty rsp...')
logger.error("empty rsp...")
return ""
# logger.info(rsp)
@ -130,8 +133,8 @@ class SearchAndSummarize(Action):
# PREFIX = self.prefix,
ROLE=self.profile,
CONTEXT=rsp,
QUERY_HISTORY='\n'.join([str(i) for i in context[:-1]]),
QUERY=str(context[-1])
QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]),
QUERY=str(context[-1]),
)
result = await self._aask(prompt, system_prompt)
logger.debug(prompt)