add search report

This commit is contained in:
shenchucheng 2024-08-13 15:12:24 +08:00
parent 120075c250
commit 2a4e3730e1

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import json
from pydantic import Field, model_validator
from pydantic import Field, PrivateAttr, model_validator
from metagpt.actions import Action
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
@ -90,6 +90,8 @@ class SearchEnhancedQA(Action):
description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.",
)
# Reporter used to emit the "search" progress events for this action.
# Built through a factory so each action instance gets its own freshly
# constructed ThoughtReporter, instead of one reporter being instantiated at
# class-definition time and reused as the default for every instance.
# NOTE(review): the answer stage reads self._reporter.uuid to tag its own
# reporter — presumably that id should be distinct per action instance; confirm.
_reporter: ThoughtReporter = PrivateAttr(default_factory=ThoughtReporter)
@model_validator(mode="after")
def initialize(self):
if self.web_browse_and_summarize_action is None:
@ -118,13 +120,14 @@ class SearchEnhancedQA(Action):
Raises:
ValueError: If the query is invalid.
"""
async with self._reporter:
await self._reporter.async_report({"type": "search", "stage": "init"})
self._validate_query(query)
self._validate_query(query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
return await self._generate_answer(processed_query, context)
return await self._generate_answer(processed_query, context)
def _validate_query(self, query: str) -> None:
"""Validate the input query.
@ -203,6 +206,7 @@ class SearchEnhancedQA(Action):
"""
relevant_urls = await self._collect_relevant_links(query)
await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls})
if not relevant_urls:
logger.warning(f"No relevant URLs found for query: {query}")
return []
@ -245,10 +249,12 @@ class SearchEnhancedQA(Action):
contents = await self._fetch_web_contents(urls)
summaries = {}
await self._reporter.async_report(
{"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]}
)
for content in contents:
url = content.url
inner_text = content.inner_text.replace("\n", "")
if self.web_browse_and_summarize_action._is_content_invalid(inner_text):
logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...")
continue
@ -276,8 +282,7 @@ class SearchEnhancedQA(Action):
system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context)
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "quick"})
async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "search", "stage": "answer"})
rsp = await self._aask(query, [system_prompt])
return rsp