Merge branch 'feature-search-report' into 'mgx_ops'

Feature search report

See merge request pub/MetaGPT!303
This commit is contained in:
张雷 2024-08-13 08:11:00 +00:00
commit 0604d3bb02
4 changed files with 33 additions and 32 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import json
from pydantic import Field, model_validator
from pydantic import Field, PrivateAttr, model_validator
from metagpt.actions import Action
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
@ -90,6 +90,8 @@ class SearchEnhancedQA(Action):
description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.",
)
_reporter: ThoughtReporter = PrivateAttr(ThoughtReporter())
@model_validator(mode="after")
def initialize(self):
if self.web_browse_and_summarize_action is None:
@ -118,13 +120,14 @@ class SearchEnhancedQA(Action):
Raises:
ValueError: If the query is invalid.
"""
async with self._reporter:
await self._reporter.async_report({"type": "search", "stage": "init"})
self._validate_query(query)
self._validate_query(query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
return await self._generate_answer(processed_query, context)
return await self._generate_answer(processed_query, context)
def _validate_query(self, query: str) -> None:
"""Validate the input query.
@ -203,6 +206,7 @@ class SearchEnhancedQA(Action):
"""
relevant_urls = await self._collect_relevant_links(query)
await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls})
if not relevant_urls:
logger.warning(f"No relevant URLs found for query: {query}")
return []
@ -245,10 +249,12 @@ class SearchEnhancedQA(Action):
contents = await self._fetch_web_contents(urls)
summaries = {}
await self._reporter.async_report(
{"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]}
)
for content in contents:
url = content.url
inner_text = content.inner_text.replace("\n", "")
if self.web_browse_and_summarize_action._is_content_invalid(inner_text):
logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...")
continue
@ -276,8 +282,7 @@ class SearchEnhancedQA(Action):
system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context)
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "quick"})
async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "search", "stage": "answer"})
rsp = await self._aask(query, [system_prompt])
return rsp

View file

@ -7,7 +7,7 @@
"""
from typing import Callable, Optional
from pydantic import Field
from pydantic import ConfigDict, Field
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@ -16,10 +16,11 @@ from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
model_config = ConfigDict(extra="allow")
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
api_key: str = ""
cse_id: str = "" # for google
discovery_service_url: str = "" # for google
search_func: Optional[Callable] = None
params: dict = Field(
default_factory=lambda: {

View file

@ -6,7 +6,7 @@
@File : search_engine_serpapi.py
"""
import warnings
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
url: str = "https://serpapi.com/search"
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@ -49,22 +50,18 @@ class SerpAPIWrapper(BaseModel):
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
params["output"] = "json"
url = "https://serpapi.com/search"
return url, params
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
params["output"] = "json"
url, params = construct_url_and_params()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, proxy=self.proxy) as response:
async with session.get(self.url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()

View file

@ -7,7 +7,7 @@
"""
import json
import warnings
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -17,6 +17,7 @@ class SerperWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
api_key: str
url: str = "https://google.serper.dev/search"
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@ -33,6 +34,7 @@ class SerperWrapper(BaseModel):
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://serper.dev/."
)
return values
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
@ -46,20 +48,16 @@ class SerperWrapper(BaseModel):
async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
payloads = self.get_payloads(queries, max_results)
headers = self.get_headers()
url, payloads, headers = construct_url_and_payload_and_headers()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
async with session.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
async with self.aiosession.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()