From 2a4e3730e10cd6eb96d718b71a4eee80610e6508 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Tue, 13 Aug 2024 15:12:24 +0800 Subject: [PATCH] add search report --- metagpt/actions/search_enhanced_qa.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/metagpt/actions/search_enhanced_qa.py b/metagpt/actions/search_enhanced_qa.py index 1d7944d61..152e615b6 100644 --- a/metagpt/actions/search_enhanced_qa.py +++ b/metagpt/actions/search_enhanced_qa.py @@ -4,7 +4,7 @@ from __future__ import annotations import json -from pydantic import Field, model_validator +from pydantic import Field, PrivateAttr, model_validator from metagpt.actions import Action from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize @@ -90,6 +90,8 @@ class SearchEnhancedQA(Action): description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.", ) + _reporter: ThoughtReporter = PrivateAttr(ThoughtReporter()) + @model_validator(mode="after") def initialize(self): if self.web_browse_and_summarize_action is None: @@ -118,13 +120,14 @@ class SearchEnhancedQA(Action): Raises: ValueError: If the query is invalid. """ + async with self._reporter: + await self._reporter.async_report({"type": "search", "stage": "init"}) + self._validate_query(query) - self._validate_query(query) + processed_query = await self._process_query(query, rewrite_query) + context = await self._build_context(processed_query) - processed_query = await self._process_query(query, rewrite_query) - context = await self._build_context(processed_query) - - return await self._generate_answer(processed_query, context) + return await self._generate_answer(processed_query, context) def _validate_query(self, query: str) -> None: """Validate the input query. @@ -203,6 +206,7 @@ class SearchEnhancedQA(Action): """ relevant_urls = await self._collect_relevant_links(query) + await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls}) if not relevant_urls: logger.warning(f"No relevant URLs found for query: {query}") return [] @@ -245,10 +249,12 @@ class SearchEnhancedQA(Action): contents = await self._fetch_web_contents(urls) summaries = {} + await self._reporter.async_report( + {"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]} + ) for content in contents: url = content.url inner_text = content.inner_text.replace("\n", "") - if self.web_browse_and_summarize_action._is_content_invalid(inner_text): logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...") continue @@ -276,8 +282,7 @@ class SearchEnhancedQA(Action): system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context) - async with ThoughtReporter(enable_llm_stream=True) as reporter: - await reporter.async_report({"type": "quick"}) + async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "search", "stage": "answer"}) rsp = await self._aask(query, [system_prompt]) - return rsp