Merge branch 'mgx_ops' into feat-intention-fs

This commit is contained in:
Yizhou Chi 2024-08-14 10:18:57 +08:00
commit aee672b8bb
9 changed files with 104 additions and 42 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import json
from pydantic import Field, model_validator
from pydantic import Field, PrivateAttr, model_validator
from metagpt.actions import Action
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
@ -45,7 +45,9 @@ Follow **Instructions**, generate output and make sure it follows the **Constrai
SEARCH_ENHANCED_QA_SYSTEM_PROMPT = """
You are a large language AI assistant built by MGX. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context.
Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context do not provide sufficient information. Do not include [citation] in your anwser.
Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context do not provide sufficient information.
Do not include [citation:x] in your anwser, where x is a number. Other than code and specific names and citations, your answer must be written in the same language as the question.
Here are the set of contexts:
@ -90,10 +92,12 @@ class SearchEnhancedQA(Action):
description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.",
)
_reporter: ThoughtReporter = PrivateAttr(ThoughtReporter())
@model_validator(mode="after")
def initialize(self):
if self.web_browse_and_summarize_action is None:
self.web_browser_engine = WebBrowserEngine.from_browser_config(
web_browser_engine = WebBrowserEngine.from_browser_config(
self.config.browser,
proxy=self.config.proxy,
java_script_enabled=self.java_script_enabled,
@ -101,7 +105,7 @@ class SearchEnhancedQA(Action):
user_agent=self.user_agent,
)
self.web_browse_and_summarize_action = WebBrowseAndSummarize(web_browser_engine=self.web_browser_engine)
self.web_browse_and_summarize_action = WebBrowseAndSummarize(web_browser_engine=web_browser_engine)
return self
@ -118,13 +122,14 @@ class SearchEnhancedQA(Action):
Raises:
ValueError: If the query is invalid.
"""
async with self._reporter:
await self._reporter.async_report({"type": "search", "stage": "init"})
self._validate_query(query)
self._validate_query(query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
return await self._generate_answer(processed_query, context)
return await self._generate_answer(processed_query, context)
def _validate_query(self, query: str) -> None:
"""Validate the input query.
@ -203,6 +208,7 @@ class SearchEnhancedQA(Action):
"""
relevant_urls = await self._collect_relevant_links(query)
await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls})
if not relevant_urls:
logger.warning(f"No relevant URLs found for query: {query}")
return []
@ -245,10 +251,12 @@ class SearchEnhancedQA(Action):
contents = await self._fetch_web_contents(urls)
summaries = {}
await self._reporter.async_report(
{"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]}
)
for content in contents:
url = content.url
inner_text = content.inner_text.replace("\n", "")
if self.web_browse_and_summarize_action._is_content_invalid(inner_text):
logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...")
continue
@ -276,8 +284,7 @@ class SearchEnhancedQA(Action):
system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context)
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "quick"})
async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "search", "stage": "answer"})
rsp = await self._aask(query, [system_prompt])
return rsp

View file

@ -7,7 +7,7 @@
"""
from typing import Callable, Optional
from pydantic import Field
from pydantic import ConfigDict, Field
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@ -16,10 +16,11 @@ from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
model_config = ConfigDict(extra="allow")
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
api_key: str = ""
cse_id: str = "" # for google
discovery_service_url: str = "" # for google
search_func: Optional[Callable] = None
params: dict = Field(
default_factory=lambda: {

View file

@ -208,6 +208,6 @@ class CodeReview(Action):
comments = await self.confirm_comments(patch=patch, comments=comments, points=points)
for comment in comments:
if comment["code"]:
if not (comment["code"].startswith("-") or comment["code"].isspace()):
if not (comment["code"].isspace()):
result.append(comment)
return result

View file

@ -168,7 +168,7 @@ Response Category: [QUICK/SEARCH/TASK/AMBIGUOUS]
"""
QUICK_THINK_EXAMPLES ="""
QUICK_THINK_EXAMPLES = """
# Example
1. Request: "How do I design an online document editing platform that supports real-time collaboration?"
@ -203,4 +203,4 @@ Response Category: TASK.
Thought: The request is vague and lacks specifics, requiring clarification on the process to optimize.
Response Category: AMBIGUOUS.
"""
"""

View file

@ -37,7 +37,11 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser, any_to_str
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
from metagpt.utils.repair_llm_raw_output import (
RepairType,
repair_escape_error,
repair_llm_raw_output,
)
from metagpt.utils.report import ThoughtReporter
@ -321,10 +325,20 @@ class RoleZero(Role):
if commands.endswith("]") and not commands.startswith("["):
commands = "[" + commands
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError:
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON for: {self.command_rsp}. Trying to repair...")
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
commands = await self.llm.aask(
msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp, json_decode_error=str(e))
)
try:
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except json.JSONDecodeError:
# repair escape error of code and math
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
new_command = repair_escape_error(commands)
commands = json.loads(
repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON)
)
except Exception as e:
tb = traceback.format_exc()
print(tb)

View file

@ -31,6 +31,10 @@ async def default_get_env(key: str, app_name: str = None) -> str:
if app_key in os.environ:
return os.environ[app_key]
env_app_key = app_key.replace("-", "_") # "-" is not supported by linux environment variable
if env_app_key in os.environ:
return os.environ[env_app_key]
from metagpt.context import Context
context = Context()

View file

@ -6,7 +6,7 @@
@File : search_engine_serpapi.py
"""
import warnings
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
url: str = "https://serpapi.com/search"
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@ -49,22 +50,18 @@ class SerpAPIWrapper(BaseModel):
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
params["output"] = "json"
url = "https://serpapi.com/search"
return url, params
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
params["output"] = "json"
url, params = construct_url_and_params()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, proxy=self.proxy) as response:
async with session.get(self.url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()

View file

@ -7,7 +7,7 @@
"""
import json
import warnings
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -17,6 +17,7 @@ class SerperWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
api_key: str
url: str = "https://google.serper.dev/search"
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@ -33,6 +34,7 @@ class SerperWrapper(BaseModel):
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://serper.dev/."
)
return values
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
@ -46,20 +48,16 @@ class SerperWrapper(BaseModel):
async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
payloads = self.get_payloads(queries, max_results)
headers = self.get_headers()
url, payloads, headers = construct_url_and_payload_and_headers()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
async with session.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
async with self.aiosession.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()

View file

@ -347,3 +347,44 @@ def extract_state_value_from_output(content: str) -> str:
matches = list(set(matches))
state = matches[0] if len(matches) > 0 else "-1"
return state
def repair_escape_error(commands):
"""
Repaires escape errors in command responses.
When RoleZero parses a command, the command may contain unknown escape characters.
This function has two steps:
1. Transform unescaped substrings like "\d" and "\(" to "\\\\d" and "\\\\(".
2. Transform escaped characters like '\f' to substrings like "\\\\f".
Example:
When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ",
The "content" will be parsed correctly to "\( \frac{1}{2} \)".
However, if the original JSON string is " {"content":"\( \frac{1}{2} \)"}" directly.
It will cause a parsing error.
To repair the wrong JSON string, the following transformations will be used:
"\(" ---> "\\\\("
'\f' ---> "\\\\f"
"\)" ---> "\\\\)"
"""
escape_repair_map = {
"\a": "\\\\a",
"\b": "\\\\b",
"\f": "\\\\f",
"\r": "\\\\r",
"\t": "\\\\t",
"\v": "\\\\v",
}
new_command = ""
for index, ch in enumerate(commands):
if ch == "\\" and index + 1 < len(commands):
if commands[index + 1] not in ["n", '"', " "]:
new_command += "\\"
elif ch in escape_repair_map:
ch = escape_repair_map[ch]
new_command += ch
return new_command