+
diff --git a/docs/resources/20230808-002840.jpg b/docs/resources/20230808-002840.jpg
deleted file mode 100644
index 1d4852930..000000000
Binary files a/docs/resources/20230808-002840.jpg and /dev/null differ
diff --git a/docs/resources/20230808-220924.jpg b/docs/resources/20230808-220924.jpg
deleted file mode 100644
index 3226e2366..000000000
Binary files a/docs/resources/20230808-220924.jpg and /dev/null differ
diff --git a/docs/resources/20230811-214014.jpg b/docs/resources/20230811-214014.jpg
new file mode 100644
index 000000000..2006f2646
Binary files /dev/null and b/docs/resources/20230811-214014.jpg differ
diff --git a/docs/resources/MetaGPT-WeChat-Group-Simple.png b/docs/resources/MetaGPT-WeChat-Group-Simple.png
deleted file mode 100644
index 771a235c9..000000000
Binary files a/docs/resources/MetaGPT-WeChat-Group-Simple.png and /dev/null differ
diff --git a/docs/resources/MetaGPT-WeChat-Group.jpeg b/docs/resources/MetaGPT-WeChat-Group.jpeg
deleted file mode 100644
index 8e60cfd94..000000000
Binary files a/docs/resources/MetaGPT-WeChat-Group.jpeg and /dev/null differ
diff --git a/docs/resources/MetaGPT-WeChat-Group4.jpeg b/docs/resources/MetaGPT-WeChat-Group4.jpeg
deleted file mode 100644
index f665f8b1d..000000000
Binary files a/docs/resources/MetaGPT-WeChat-Group4.jpeg and /dev/null differ
diff --git a/docs/resources/MetaGPT-WeChat-Personal-new.jpg b/docs/resources/MetaGPT-WeChat-Personal-new.jpg
deleted file mode 100644
index 9a5ae5a56..000000000
Binary files a/docs/resources/MetaGPT-WeChat-Personal-new.jpg and /dev/null differ
diff --git a/examples/research.py b/examples/research.py
new file mode 100644
index 000000000..344f8d0e9
--- /dev/null
+++ b/examples/research.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+
+import asyncio
+
+from metagpt.roles.researcher import RESEARCH_PATH, Researcher
+
+
+async def main():
+ topic = "dataiku vs. datarobot"
+ role = Researcher(language="en-us")
+ await role.run(topic)
+ print(f"save report to {RESEARCH_PATH / f'{topic}.md'}.")
+
+
+if __name__ == '__main__':
+ asyncio.run(main())
diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py
index 0c861aa69..b004bd58e 100644
--- a/metagpt/actions/__init__.py
+++ b/metagpt/actions/__init__.py
@@ -15,6 +15,7 @@ from metagpt.actions.design_api import WriteDesign
from metagpt.actions.design_api_review import DesignReview
from metagpt.actions.design_filenames import DesignFilenames
from metagpt.actions.project_management import AssignTasks, WriteTasks
+from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
from metagpt.actions.run_code import RunCode
from metagpt.actions.search_and_summarize import SearchAndSummarize
from metagpt.actions.write_code import WriteCode
@@ -26,6 +27,7 @@ from metagpt.actions.write_test import WriteTest
class ActionType(Enum):
"""All types of Actions, used for indexing."""
+
ADD_REQUIREMENT = BossRequirement
WRITE_PRD = WritePRD
WRITE_PRD_REVIEW = WritePRDReview
@@ -40,3 +42,13 @@ class ActionType(Enum):
WRITE_TASKS = WriteTasks
ASSIGN_TASKS = AssignTasks
SEARCH_AND_SUMMARIZE = SearchAndSummarize
+ COLLECT_LINKS = CollectLinks
+ WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
+ CONDUCT_RESEARCH = ConductResearch
+
+
+__all__ = [
+ "ActionType",
+ "Action",
+ "ActionOutput",
+]
diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py
new file mode 100644
index 000000000..81eb876dd
--- /dev/null
+++ b/metagpt/actions/research.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import asyncio
+import json
+from typing import Callable
+
+from pydantic import parse_obj_as
+
+from metagpt.actions import Action
+from metagpt.config import CONFIG
+from metagpt.logs import logger
+from metagpt.tools.search_engine import SearchEngine
+from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
+from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
+
+LANG_PROMPT = "Please respond in {language}."
+
+RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \
+written, critically acclaimed, objective and structured reports on the given text."""
+
+RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is:\n#TOPIC#\n{topic}"
+
+SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic for Google search. \
+Your response must be in JSON format, for example: ["keyword1", "keyword2"]."""
+
+SUMMARIZE_SEARCH_PROMPT = """### Requirements
+1. The keywords related to your research topic and the search results are shown in the "Search Result Information" section.
+2. Provide up to {decomposition_nums} queries related to your research topic base on the search results.
+3. Please respond in the following JSON format: ["query1", "query2", "query3", ...].
+
+### Search Result Information
+{search_results}
+"""
+
+COLLECT_AND_RANKURLS_PROMPT = """### Topic
+{topic}
+### Query
+{query}
+
+### The online search results
+{results}
+
+### Requirements
+Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \
+based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the
+ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words.
+"""
+
+WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements
+1. Utilize the text in the "Reference Information" section to respond to the question "{query}".
+2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \
+a comprehensive summary of the text.
+3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant."
+4. Include all relevant factual information, numbers, statistics, etc., if available.
+
+### Reference Information
+{content}
+'''
+
+
+CONDUCT_RESEARCH_PROMPT = '''### Reference Information
+{content}
+
+### Requirements
+Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \
+above. The report must meet the following requirements:
+
+- Focus on directly addressing the chosen topic.
+- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available.
+- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable.
+- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines.
+- Include all source URLs in APA format at the end of the report.
+'''
+
+
+class CollectLinks(Action):
+ """Action class to collect links from a search engine."""
+ def __init__(
+ self,
+ name: str = "",
+ *args,
+ rank_func: Callable[[list[str]], None] | None = None,
+ **kwargs,
+ ):
+ super().__init__(name, *args, **kwargs)
+ self.desc = "Collect links from a search engine."
+ self.search_engine = SearchEngine()
+ self.rank_func = rank_func
+
+ async def run(
+ self,
+ topic: str,
+ decomposition_nums: int = 4,
+ url_per_query: int = 4,
+ system_text: str | None = None,
+ ) -> dict[str, list[str]]:
+ """Run the action to collect links.
+
+ Args:
+ topic: The research topic.
+ decomposition_nums: The number of search questions to generate.
+ url_per_query: The number of URLs to collect per search question.
+ system_text: The system text.
+
+ Returns:
+ A dictionary containing the search questions as keys and the collected URLs as values.
+ """
+ system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
+ keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
+ try:
+ keywords = json.loads(keywords)
+ keywords = parse_obj_as(list[str], keywords)
+ except Exception as e:
+ logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
+ keywords = [topic]
+ results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))
+
+ def gen_msg():
+ while True:
+ search_results = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results))
+ prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search_results=search_results)
+ yield prompt
+ remove = max(results, key=len)
+ remove.pop()
+ if len(remove) == 0:
+ break
+ prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
+ logger.debug(prompt)
+ queries = await self._aask(prompt, [system_text])
+ try:
+ queries = json.loads(queries)
+ queries = parse_obj_as(list[str], queries)
+ except Exception as e:
+ logger.exception(f"fail to break down the research question due to {e}")
+ queries = keywords
+ ret = {}
+ for query in queries:
+ ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
+ return ret
+
+ async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
+ """Search and rank URLs based on a query.
+
+ Args:
+ topic: The research topic.
+ query: The search query.
+ num_results: The number of URLs to collect.
+
+ Returns:
+ A list of ranked URLs.
+ """
+ max_results = max(num_results * 2, 6)
+ results = await self.search_engine.run(query, max_results=max_results, as_string=False)
+ _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
+ prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
+ logger.debug(prompt)
+ indices = await self._aask(prompt)
+ try:
+ indices = json.loads(indices)
+ assert all(isinstance(i, int) for i in indices)
+ except Exception as e:
+ logger.exception(f"fail to rank results for {e}")
+ indices = list(range(max_results))
+ results = [results[i] for i in indices]
+ if self.rank_func:
+ results = self.rank_func(results)
+ return [i["link"] for i in results[:num_results]]
+
+
+class WebBrowseAndSummarize(Action):
+ """Action class to explore the web and provide summaries of articles and webpages."""
+ def __init__(
+ self,
+ *args,
+ browse_func: Callable[[list[str]], None] | None = None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ if CONFIG.model_for_researcher_summary:
+ self.llm.model = CONFIG.model_for_researcher_summary
+ self.web_browser_engine = WebBrowserEngine(
+ engine=WebBrowserEngineType.CUSTOM if browse_func else None,
+ run_func=browse_func,
+ )
+ self.desc = "Explore the web and provide summaries of articles and webpages."
+
+ async def run(
+ self,
+ url: str,
+ *urls: str,
+ query: str,
+ system_text: str = RESEARCH_BASE_SYSTEM,
+ ) -> dict[str, str]:
+ """Run the action to browse the web and provide summaries.
+
+ Args:
+ url: The main URL to browse.
+ urls: Additional URLs to browse.
+ query: The research question.
+ system_text: The system text.
+
+ Returns:
+ A dictionary containing the URLs as keys and their summaries as values.
+ """
+ contents = await self.web_browser_engine.run(url, *urls)
+ if not urls:
+ contents = [contents]
+
+ summaries = {}
+ prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
+ for u, content in zip([url, *urls], contents):
+ content = content.inner_text
+ chunk_summaries = []
+ for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp):
+ logger.debug(prompt)
+ summary = await self._aask(prompt, [system_text])
+ if summary == "Not relevant.":
+ continue
+ chunk_summaries.append(summary)
+
+ if not chunk_summaries:
+ summaries[u] = None
+ continue
+
+ if len(chunk_summaries) == 1:
+ summaries[u] = chunk_summaries[0]
+ continue
+
+ content = "\n".join(chunk_summaries)
+ prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
+ summary = await self._aask(prompt, [system_text])
+ summaries[u] = summary
+ return summaries
+
+
+class ConductResearch(Action):
+ """Action class to conduct research and generate a research report."""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if CONFIG.model_for_researcher_report:
+ self.llm.model = CONFIG.model_for_researcher_report
+
+ async def run(
+ self,
+ topic: str,
+ content: str,
+ system_text: str = RESEARCH_BASE_SYSTEM,
+ ) -> str:
+ """Run the action to conduct research and generate a research report.
+
+ Args:
+ topic: The research topic.
+ content: The content for research.
+ system_text: The system text.
+
+ Returns:
+ The generated research report.
+ """
+ prompt = CONDUCT_RESEARCH_PROMPT.format(topic=topic, content=content)
+ logger.debug(prompt)
+ self.llm.auto_max_tokens = True
+ return await self._aask(prompt, [system_text])
+
+
+def get_research_system_text(topic: str, language: str):
+ """Get the system text for conducting research.
+
+ Args:
+ topic: The research topic.
+ language: The language for the system text.
+
+ Returns:
+ The system text for conducting research.
+ """
+ return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language)))
diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py
index 1bc5cc13a..f69d2cd1a 100644
--- a/metagpt/actions/run_code.py
+++ b/metagpt/actions/run_code.py
@@ -5,13 +5,13 @@
@Author : alexanderwu
@File : run_code.py
"""
-import traceback
import os
import subprocess
-from typing import List, Tuple
+import traceback
+from typing import Tuple
-from metagpt.logs import logger
from metagpt.actions.action import Action
+from metagpt.logs import logger
PROMPT_TEMPLATE = """
Role: You are a senior development and qa engineer, your role is summarize the code running result.
@@ -27,7 +27,7 @@ Please summarize the cause of the errors and give correction instruction
Determine the ONE file to rewrite in order to fix the error, for example, xyz.py, or test_xyz.py
## Status:
Determine if all of the code works fine, if so write PASS, else FAIL,
-WRITE ONLY ONE WORD, PASS OR FAIL, IN THI SECTION
+WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION
## Send To:
Please write Engineer if the errors are due to problematic development codes, and QaEngineer to problematic test codes, and NoOne if there are no errors,
WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.
@@ -55,6 +55,7 @@ standard output: {outs};
standard errors: {errs};
"""
+
class RunCode(Action):
def __init__(self, name="RunCode", context=None, llm=None):
super().__init__(name, context, llm)
@@ -65,7 +66,7 @@ class RunCode(Action):
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
- return namespace.get('result', ""), ""
+ return namespace.get("result", ""), ""
except Exception:
# If there is an error in the code, return the error message
return "", traceback.format_exc()
@@ -81,10 +82,12 @@ class RunCode(Action):
# Modify the PYTHONPATH environment variable
additional_python_paths = [working_directory] + additional_python_paths
additional_python_paths = ":".join(additional_python_paths)
- env['PYTHONPATH'] = additional_python_paths + ':' + env.get('PYTHONPATH', '')
+ env["PYTHONPATH"] = additional_python_paths + ":" + env.get("PYTHONPATH", "")
# Start the subprocess
- process = subprocess.Popen(command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
+ process = subprocess.Popen(
+ command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+ )
try:
# Wait for the process to complete, with a timeout
@@ -93,7 +96,7 @@ class RunCode(Action):
logger.info("The command did not complete within the given timeout.")
process.kill() # Kill the process if it times out
stdout, stderr = process.communicate()
- return stdout.decode('utf-8'), stderr.decode('utf-8')
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
async def run(
self, code, mode="script", code_file_name="", test_code="", test_file_name="", command=[], **kwargs
@@ -108,11 +111,13 @@ class RunCode(Action):
logger.info(f"{errs=}")
context = CONTEXT.format(
- code=code, code_file_name=code_file_name,
- test_code=test_code, test_file_name=test_file_name,
+ code=code,
+ code_file_name=code_file_name,
+ test_code=test_code,
+ test_file_name=test_file_name,
command=" ".join(command),
- outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
- errs=errs[:10000] # truncate errors to avoid token overflow
+ outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
+ errs=errs[:10000], # truncate errors to avoid token overflow
)
prompt = PROMPT_TEMPLATE.format(context=context)
diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py
index e1c1571c3..5e50fdb55 100644
--- a/metagpt/actions/write_test.py
+++ b/metagpt/actions/write_test.py
@@ -5,7 +5,6 @@
@Author : alexanderwu
@File : write_test.py
"""
-from metagpt.logs import logger
from metagpt.actions.action import Action
from metagpt.utils.common import CodeParser
@@ -29,6 +28,7 @@ you should correctly import the necessary classes based on these file locations!
## {test_file_name}: Write test code with triple quoto. Do your best to implement THIS ONLY ONE FILE.
"""
+
class WriteTest(Action):
def __init__(self, name="WriteTest", context=None, llm=None):
super().__init__(name, context, llm)
@@ -43,7 +43,7 @@ class WriteTest(Action):
code_to_test=code_to_test,
test_file_name=test_file_name,
source_file_path=source_file_path,
- workspace=workspace
+ workspace=workspace,
)
code = await self.write_code(prompt)
return code
diff --git a/metagpt/config.py b/metagpt/config.py
index d53571468..fb1aa485c 100644
--- a/metagpt/config.py
+++ b/metagpt/config.py
@@ -4,14 +4,14 @@
提供配置,单例
"""
import os
-import openai
+import openai
import yaml
from metagpt.const import PROJECT_ROOT
from metagpt.logs import logger
-from metagpt.utils.singleton import Singleton
from metagpt.tools import SearchEngineType, WebBrowserEngineType
+from metagpt.utils.singleton import Singleton
class NotConfiguredException(Exception):
@@ -46,7 +46,6 @@ class Config(metaclass=Singleton):
self.openai_api_key = self._get("OPENAI_API_KEY")
if not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key:
raise NotConfiguredException("Set OPENAI_API_KEY first")
-
self.openai_api_base = self._get("OPENAI_API_BASE")
if not self.openai_api_base or "YOUR_API_BASE" == self.openai_api_base:
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
@@ -67,22 +66,22 @@ class Config(metaclass=Singleton):
self.google_api_key = self._get("GOOGLE_API_KEY")
self.google_cse_id = self._get("GOOGLE_CSE_ID")
self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE)
-
+
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright"))
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
-
+
self.long_term_memory = self._get('LONG_TERM_MEMORY', False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")
self.max_budget = self._get("MAX_BUDGET", 10.0)
self.total_cost = 0.0
+
self.puppeteer_config = self._get("PUPPETEER_CONFIG","")
self.mmdc = self._get("MMDC","mmdc")
- self.update_costs = self._get("UPDATE_COSTS",True)
self.calc_usage = self._get("CALC_USAGE",True)
-
-
+ self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
+ self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
"""从config/key.yaml / config/config.yaml / env三处按优先级递减加载"""
diff --git a/metagpt/const.py b/metagpt/const.py
index abbfb40e0..505eebd46 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -32,5 +32,6 @@ UT_PY_PATH = UT_PATH / "files/ut/"
API_QUESTIONS_PATH = UT_PATH / "files/question/"
YAPI_URL = "http://yapi.deepwisdomai.com/"
TMP = PROJECT_ROOT / 'tmp'
+RESEARCH_PATH = DATA_PATH / "research"
MEM_TTL = 24 * 30 * 3600
diff --git a/metagpt/document_store/__init__.py b/metagpt/document_store/__init__.py
index 7d7c6e5e9..766e141a5 100644
--- a/metagpt/document_store/__init__.py
+++ b/metagpt/document_store/__init__.py
@@ -7,3 +7,5 @@
"""
from metagpt.document_store.faiss_store import FaissStore
+
+__all__ = ["FaissStore"]
diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py
index 01877e106..3dc96c0d6 100644
--- a/metagpt/document_store/base_store.py
+++ b/metagpt/document_store/base_store.py
@@ -15,7 +15,7 @@ class BaseStore(ABC):
"""FIXME: consider add_index, set_index and think 颗粒度"""
@abstractmethod
- def search(self, query, *args, **kwargs):
+ def search(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod
diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py
new file mode 100644
index 000000000..98b82cf87
--- /dev/null
+++ b/metagpt/document_store/qdrant_store.py
@@ -0,0 +1,129 @@
+from dataclasses import dataclass
+from typing import List
+
+from qdrant_client import QdrantClient
+from qdrant_client.models import Filter, PointStruct, VectorParams
+
+from metagpt.document_store.base_store import BaseStore
+
+
+@dataclass
+class QdrantConnection:
+ """
+ Args:
+ url: qdrant url
+ host: qdrant host
+ port: qdrant port
+ memory: qdrant service use memory mode
+ api_key: qdrant cloud api_key
+ """
+ url: str = None
+ host: str = None
+ port: int = None
+ memory: bool = False
+ api_key: str = None
+
+
+class QdrantStore(BaseStore):
+ def __init__(self, connect: QdrantConnection):
+ if connect.memory:
+ self.client = QdrantClient(":memory:")
+ elif connect.url:
+ self.client = QdrantClient(url=connect.url, api_key=connect.api_key)
+ elif connect.host and connect.port:
+ self.client = QdrantClient(
+ host=connect.host, port=connect.port, api_key=connect.api_key
+ )
+ else:
+ raise Exception("please check QdrantConnection.")
+
+ def create_collection(
+ self,
+ collection_name: str,
+ vectors_config: VectorParams,
+ force_recreate=False,
+ **kwargs,
+ ):
+ """
+ create a collection
+ Args:
+ collection_name: collection name
+ vectors_config: VectorParams object,detail in https://github.com/qdrant/qdrant-client
+ force_recreate: default is False, if True, will delete exists collection,then create it
+ **kwargs:
+
+ Returns:
+
+ """
+ try:
+ self.client.get_collection(collection_name)
+ if force_recreate:
+ res = self.client.recreate_collection(
+ collection_name, vectors_config=vectors_config, **kwargs
+ )
+ return res
+ return True
+ except: # noqa: E722
+ return self.client.recreate_collection(
+ collection_name, vectors_config=vectors_config, **kwargs
+ )
+
+ def has_collection(self, collection_name: str):
+ try:
+ self.client.get_collection(collection_name)
+ return True
+ except: # noqa: E722
+ return False
+
+ def delete_collection(self, collection_name: str, timeout=60):
+ res = self.client.delete_collection(collection_name, timeout=timeout)
+ if not res:
+ raise Exception(f"Delete collection {collection_name} failed.")
+
+ def add(self, collection_name: str, points: List[PointStruct]):
+ """
+ add some vector data to qdrant
+ Args:
+ collection_name: collection name
+ points: list of PointStruct object, about PointStruct detail in https://github.com/qdrant/qdrant-client
+
+ Returns: NoneX
+
+ """
+ # self.client.upload_records()
+ self.client.upsert(
+ collection_name,
+ points,
+ )
+
+ def search(
+ self,
+ collection_name: str,
+ query: List[float],
+ query_filter: Filter = None,
+ k=10,
+ return_vector=False,
+ ):
+ """
+ vector search
+ Args:
+ collection_name: qdrant collection name
+ query: input vector
+ query_filter: Filter object, detail in https://github.com/qdrant/qdrant-client
+ k: return the most similar k pieces of data
+ return_vector: whether return vector
+
+ Returns: list of dict
+
+ """
+ hits = self.client.search(
+ collection_name=collection_name,
+ query_vector=query,
+ query_filter=query_filter,
+ limit=k,
+ with_vectors=return_vector,
+ )
+ return [hit.__dict__ for hit in hits]
+
+ def write(self, *args, **kwargs):
+ pass
diff --git a/metagpt/environment.py b/metagpt/environment.py
index c4d612d85..24e6ada2f 100644
--- a/metagpt/environment.py
+++ b/metagpt/environment.py
@@ -16,7 +16,10 @@ from metagpt.schema import Message
class Environment(BaseModel):
- """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到"""
+ """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
+ Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
+
+ """
roles: dict[str, Role] = Field(default_factory=dict)
memory: Memory = Field(default_factory=Memory)
@@ -26,23 +29,31 @@ class Environment(BaseModel):
arbitrary_types_allowed = True
def add_role(self, role: Role):
- """增加一个在当前环境的Role"""
+ """增加一个在当前环境的角色
+ Add a role in the current environment
+ """
role.set_env(self)
self.roles[role.profile] = role
def add_roles(self, roles: Iterable[Role]):
- """增加一批在当前环境的Role"""
+ """增加一批在当前环境的角色
+ Add a batch of characters in the current environment
+ """
for role in roles:
self.add_role(role)
def publish_message(self, message: Message):
- """向当前环境发布信息"""
+ """向当前环境发布信息
+ Post information to the current environment
+ """
# self.message_queue.put(message)
self.memory.add(message)
self.history += f"\n{message}"
async def run(self, k=1):
- """处理一次所有Role的运行"""
+ """处理一次所有信息的运行
+ Process all Role runs at once
+ """
# while not self.message_queue.empty():
# message = self.message_queue.get()
# rsp = await self.manager.handle(message, self)
@@ -56,9 +67,13 @@ class Environment(BaseModel):
await asyncio.gather(*futures)
def get_roles(self) -> dict[str, Role]:
- """获得环境内的所有Role"""
+ """获得环境内的所有角色
+ Process all Role runs at once
+ """
return self.roles
def get_role(self, name: str) -> Role:
- """获得环境内的指定Role"""
+ """获得环境内的指定角色
+ get all the environment roles
+ """
return self.roles.get(name, None)
diff --git a/metagpt/llm.py b/metagpt/llm.py
index ae7f4c6f1..6a9a9132f 100644
--- a/metagpt/llm.py
+++ b/metagpt/llm.py
@@ -14,5 +14,7 @@ CLAUDE_LLM = Claude()
async def ai_func(prompt):
- """使用LLM进行QA"""
+ """使用LLM进行QA
+ QA with LLMs
+ """
return await DEFAULT_LLM.aask(prompt)
diff --git a/metagpt/logs.py b/metagpt/logs.py
index fa4befa7d..0adee23ff 100644
--- a/metagpt/logs.py
+++ b/metagpt/logs.py
@@ -14,7 +14,9 @@ from metagpt.const import PROJECT_ROOT
def define_log_level(print_level="INFO", logfile_level="DEBUG"):
- """调整日志级别到level之上"""
+ """调整日志级别到level之上
+ Adjust the log level to above level
+ """
_logger.remove()
_logger.add(sys.stderr, level=print_level)
_logger.add(PROJECT_ROOT / 'logs/log.txt', level=logfile_level)
diff --git a/metagpt/manager.py b/metagpt/manager.py
index 3cb445108..9d238c621 100644
--- a/metagpt/manager.py
+++ b/metagpt/manager.py
@@ -33,6 +33,7 @@ class Manager:
async def handle(self, message: Message, environment):
"""
管理员处理信息,现在简单的将信息递交给下一个人
+ The administrator processes the information, now simply passes the information on to the next person
:param message:
:param environment:
:return:
@@ -50,6 +51,7 @@ class Manager:
# chosen_role_name = self.llm.ask(self.prompt_template.format(context))
# FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程
+ #The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards
next_role_profile = self.role_directions[message.role]
# logger.debug(f"{next_role_profile}")
for _, role in roles.items():
diff --git a/metagpt/memory/__init__.py b/metagpt/memory/__init__.py
index 2eff0d890..710930626 100644
--- a/metagpt/memory/__init__.py
+++ b/metagpt/memory/__init__.py
@@ -9,3 +9,8 @@
from metagpt.memory.memory import Memory
from metagpt.memory.longterm_memory import LongTermMemory
+
+__all__ = [
+ "Memory",
+ "LongTermMemory",
+]
diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py
index 154fcfbda..3c2963613 100644
--- a/metagpt/memory/longterm_memory.py
+++ b/metagpt/memory/longterm_memory.py
@@ -2,12 +2,10 @@
# -*- coding: utf-8 -*-
# @Desc : the implement of Long-term memory
-from typing import Iterable, Type
-
from metagpt.logs import logger
-from metagpt.schema import Message
from metagpt.memory import Memory
from metagpt.memory.memory_storage import MemoryStorage
+from metagpt.schema import Message
class LongTermMemory(Memory):
@@ -27,10 +25,11 @@ class LongTermMemory(Memory):
messages = self.memory_storage.recover_memory(role_id)
self.rc = rc
if not self.memory_storage.is_initialized:
- logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty')
+ logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
else:
- logger.warning(f'Agent {role_id} has existed memory storage with {len(messages)} messages '
- f'and has recovered them.')
+ logger.warning(
+ f"Agent {role_id} has existed memory storage with {len(messages)} messages " f"and has recovered them."
+ )
self.msg_from_recover = True
self.add_batch(messages)
self.msg_from_recover = False
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index 785dbdd66..56dc19b4b 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -7,3 +7,6 @@
"""
from metagpt.provider.openai_api import OpenAIGPTAPI
+
+
+__all__ = ["OpenAIGPTAPI"]
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index f57a6bcf7..0f7100db8 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:08
@@ -8,10 +7,11 @@
"""
import asyncio
import time
-from functools import wraps
from typing import NamedTuple
import traceback
import openai
+from openai.error import APIConnectionError
+from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
from metagpt.config import CONFIG
from metagpt.logs import logger
@@ -21,9 +21,11 @@ from metagpt.utils.token_counter import (
TOKEN_COSTS,
count_message_tokens,
count_string_tokens,
+ get_max_completion_tokens,
)
+<<<<<<< HEAD
def retry(max_retries):
def decorator(f):
@wraps(f)
@@ -41,15 +43,20 @@ def retry(max_retries):
return decorator
+=======
+>>>>>>> main
class RateLimiter:
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
+
def __init__(self, rpm):
self.last_call_time = 0
- self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly according to time, they will still be QOS'd; consider switching to simple error retry later
+ # Here 1.1 is used because even if the calls are made strictly according to time,
+ # they will still be QOS'd; consider switching to simple error retry later
+ self.interval = 1.1 * 60 / rpm
self.rpm = rpm
def split_batches(self, batch):
- return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]
+ return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
async def wait_if_needed(self, num_requests):
current_time = time.time()
@@ -72,6 +79,7 @@ class Costs(NamedTuple):
class CostManager(metaclass=Singleton):
"""计算使用接口的开销"""
+
def __init__(self):
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
@@ -89,13 +97,12 @@ class CostManager(metaclass=Singleton):
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
- cost = (
- prompt_tokens * TOKEN_COSTS[model]["prompt"]
- + completion_tokens * TOKEN_COSTS[model]["completion"]
- ) / 1000
+ cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
self.total_cost += cost
- logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
- f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
+ logger.info(
+ f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
+ f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
+ )
CONFIG.total_cost = self.total_cost
def get_total_prompt_tokens(self):
@@ -130,14 +137,25 @@ class CostManager(metaclass=Singleton):
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
+def log_and_reraise(retry_state):
+ logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
+ logger.warning("""
+Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
+See FAQ 5.8
+""")
+ raise retry_state.outcome.exception()
+
+
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples
"""
+
def __init__(self):
self.__init_openai(CONFIG)
self.llm = openai
self.model = CONFIG.openai_api_model
+ self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
@@ -167,41 +185,42 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
- chunk_message = chunk['choices'][0]['delta'] # extract the message
+ chunk_message = chunk["choices"][0]["delta"] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
- full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
+ full_reply_content = "".join([m.get("content", "") for m in collected_messages])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict]) -> dict:
- if CONFIG.openai_api_type == 'azure':
+ if CONFIG.openai_api_type == "azure":
kwargs = {
"deployment_id": CONFIG.deployment_id,
"messages": messages,
- "max_tokens": CONFIG.max_tokens_rsp,
+ "max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
- "temperature": 0.3
+ "temperature": 0.3,
}
else:
kwargs = {
"model": self.model,
"messages": messages,
- "max_tokens": CONFIG.max_tokens_rsp,
+ "max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
- "temperature": 0.3
+ "temperature": 0.3,
}
+ kwargs["timeout"] = 3
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
- self._update_costs(rsp.get('usage'))
+ self._update_costs(rsp.get("usage"))
return rsp
def _chat_completion(self, messages: list[dict]) -> dict:
@@ -219,7 +238,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# messages = self.messages_to_dict(messages)
return await self._achat_completion(messages)
- @retry(max_retries=6)
+ @retry(
+ stop=stop_after_attempt(3),
+ wait=wait_fixed(1),
+ after=after_log(logger, logger.level('WARNING').name),
+ retry=retry_if_exception_type(APIConnectionError),
+ retry_error_callback=log_and_reraise,
+ )
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
"""when streaming, print each token in place."""
if stream:
@@ -230,11 +255,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
usage = {}
if CONFIG.calc_usage:
- prompt_tokens = count_message_tokens(messages, self.model)
- completion_tokens = count_string_tokens(rsp, self.model)
- usage['prompt_tokens'] = prompt_tokens
- usage['completion_tokens'] = completion_tokens
- return usage
+ try:
+ prompt_tokens = count_message_tokens(messages, self.model)
+ completion_tokens = count_string_tokens(rsp, self.model)
+ usage['prompt_tokens'] = prompt_tokens
+ usage['completion_tokens'] = completion_tokens
+ return usage
+ except Exception as e:
+ logger.error("usage calculation failed!", e)
+ else:
+ return usage
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
"""返回完整JSON"""
@@ -263,10 +293,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return results
def _update_costs(self, usage: dict):
- if CONFIG.update_costs:
- prompt_tokens = int(usage['prompt_tokens'])
- completion_tokens = int(usage['completion_tokens'])
- self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+ if CONFIG.calc_usage:
+ try:
+ prompt_tokens = int(usage['prompt_tokens'])
+ completion_tokens = int(usage['completion_tokens'])
+ self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+ except Exception as e:
+ logger.error("updating costs failed!", e)
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
+
+ def get_max_tokens(self, messages: list[dict]):
+ if not self.auto_max_tokens:
+ return CONFIG.max_tokens_rsp
+ return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
diff --git a/metagpt/roles/__init__.py b/metagpt/roles/__init__.py
index b1911df06..1768b786c 100644
--- a/metagpt/roles/__init__.py
+++ b/metagpt/roles/__init__.py
@@ -8,10 +8,23 @@
from metagpt.roles.role import Role
from metagpt.roles.architect import Architect
-from metagpt.roles.product_manager import ProductManager
from metagpt.roles.project_manager import ProjectManager
+from metagpt.roles.product_manager import ProductManager
from metagpt.roles.engineer import Engineer
from metagpt.roles.qa_engineer import QaEngineer
from metagpt.roles.seacher import Searcher
from metagpt.roles.sales import Sales
from metagpt.roles.customer_service import CustomerService
+
+
+__all__ = [
+ "Role",
+ "Architect",
+ "ProjectManager",
+ "ProductManager",
+ "Engineer",
+ "QaEngineer",
+ "Searcher",
+ "Sales",
+ "CustomerService",
+]
diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py
index 5e12a1abd..65bf2cc5b 100644
--- a/metagpt/roles/qa_engineer.py
+++ b/metagpt/roles/qa_engineer.py
@@ -6,40 +6,44 @@
@File : qa_engineer.py
"""
import os
-import re
from pathlib import Path
-from typing import Type
-from metagpt.actions import WriteTest, WriteCode, WriteDesign, RunCode, DebugError
+from metagpt.actions import DebugError, RunCode, WriteCode, WriteDesign, WriteTest
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
-from metagpt.roles.engineer import Engineer
from metagpt.utils.common import CodeParser, parse_recipient
-from metagpt.utils.special_tokens import MSG_SEP, FILENAME_CODE_SEP
+from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP
+
class QaEngineer(Role):
- def __init__(self, name="Edward", profile="QaEngineer",
- goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
- constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
- test_round_allowed=5):
+ def __init__(
+ self,
+ name="Edward",
+ profile="QaEngineer",
+ goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
+ constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
+ test_round_allowed=5,
+ ):
super().__init__(name, profile, goal, constraints)
- self._init_actions([WriteTest]) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
+ self._init_actions(
+ [WriteTest]
+ ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
self._watch([WriteCode, WriteTest, RunCode, DebugError])
self.test_round = 0
self.test_round_allowed = test_round_allowed
-
+
@classmethod
def parse_workspace(cls, system_design_msg: Message) -> str:
if not system_design_msg.instruct_content:
return system_design_msg.instruct_content.dict().get("Python package name")
return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
-
+
def get_workspace(self, return_proj_dir=True) -> Path:
msg = self._rc.memory.get_by_action(WriteDesign)[-1]
if not msg:
- return WORKSPACE_ROOT / 'src'
+ return WORKSPACE_ROOT / "src"
workspace = self.parse_workspace(msg)
# project directory: workspace/{package_name}, which contains package source code folder, tests folder, resources folder, etc.
if return_proj_dir:
@@ -48,49 +52,52 @@ class QaEngineer(Role):
return WORKSPACE_ROOT / workspace / workspace
def write_file(self, filename: str, code: str):
- workspace = self.get_workspace() / 'tests'
+ workspace = self.get_workspace() / "tests"
file = workspace / filename
file.parent.mkdir(parents=True, exist_ok=True)
file.write_text(code)
async def _write_test(self, message: Message) -> None:
-
code_msgs = message.content.split(MSG_SEP)
- result_msg_all = []
+ # result_msg_all = []
for code_msg in code_msgs:
-
# write tests
file_name, file_path = code_msg.split(FILENAME_CODE_SEP)
code_to_test = open(file_path, "r").read()
if "test" in file_name:
- continue # Engineer might write some test files, skip testing a test file
+ continue # Engineer might write some test files, skip testing a test file
test_file_name = "test_" + file_name
test_file_path = self.get_workspace() / "tests" / test_file_name
- logger.info(f'Writing {test_file_name}..')
+ logger.info(f"Writing {test_file_name}..")
test_code = await WriteTest().run(
code_to_test=code_to_test,
test_file_name=test_file_name,
# source_file_name=file_name,
source_file_path=file_path,
- workspace=self.get_workspace()
+ workspace=self.get_workspace(),
)
self.write_file(test_file_name, test_code)
# prepare context for run tests in next round
- command = ['python', f'tests/{test_file_name}']
+ command = ["python", f"tests/{test_file_name}"]
file_info = {
- "file_name": file_name, "file_path": str(file_path),
- "test_file_name": test_file_name, "test_file_path": str(test_file_path),
- "command": command
+ "file_name": file_name,
+ "file_path": str(file_path),
+ "test_file_name": test_file_name,
+ "test_file_path": str(test_file_path),
+ "command": command,
}
msg = Message(
- content=str(file_info), role=self.profile, cause_by=WriteTest,
- sent_from=self.profile, send_to=self.profile
+ content=str(file_info),
+ role=self.profile,
+ cause_by=WriteTest,
+ sent_from=self.profile,
+ send_to=self.profile,
)
self._publish_message(msg)
-
- logger.info(f'Done {self.get_workspace()}/tests generating.')
-
+
+ logger.info(f"Done {self.get_workspace()}/tests generating.")
+
async def _run_code(self, msg):
file_info = eval(msg.content)
development_file_path = file_info["file_path"]
@@ -110,17 +117,14 @@ class QaEngineer(Role):
test_code=test_code,
test_file_name=file_info["test_file_name"],
command=file_info["command"],
- working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here
- additional_python_paths=[development_code_dir], # workspace/package_name/package_name,
- # import statement inside package code needs this
+ working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here
+ additional_python_paths=[development_code_dir], # workspace/package_name/package_name,
+ # import statement inside package code needs this
)
- recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself
+ recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself
content = str(file_info) + FILENAME_CODE_SEP + result_msg
- msg = Message(
- content=content, role=self.profile, cause_by=RunCode,
- sent_from=self.profile, send_to=recipient
- )
+ msg = Message(content=content, role=self.profile, cause_by=RunCode, sent_from=self.profile, send_to=recipient)
self._publish_message(msg)
async def _debug_error(self, msg):
@@ -128,21 +132,27 @@ class QaEngineer(Role):
file_name, code = await DebugError().run(context)
if file_name:
self.write_file(file_name, code)
- recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self
- msg = Message(content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient)
+ recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self
+ msg = Message(
+ content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient
+ )
self._publish_message(msg)
-
+
async def _observe(self) -> int:
await super()._observe()
- self._rc.news = [msg for msg in self._rc.news \
- if msg.send_to == self.profile] # only relevant msgs count as observed news
+ self._rc.news = [
+ msg for msg in self._rc.news if msg.send_to == self.profile
+ ] # only relevant msgs count as observed news
return len(self._rc.news)
async def _act(self) -> Message:
if self.test_round > self.test_round_allowed:
result_msg = Message(
content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)",
- role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to=""
+ role=self.profile,
+ cause_by=WriteTest,
+ sent_from=self.profile,
+ send_to="",
)
return result_msg
@@ -161,6 +171,9 @@ class QaEngineer(Role):
self.test_round += 1
result_msg = Message(
content=f"Round {self.test_round} of tests done",
- role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to=""
+ role=self.profile,
+ cause_by=WriteTest,
+ sent_from=self.profile,
+ send_to="",
)
return result_msg
diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py
new file mode 100644
index 000000000..815cfa172
--- /dev/null
+++ b/metagpt/roles/researcher.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python
+
+import asyncio
+
+from pydantic import BaseModel
+
+from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize
+from metagpt.actions.research import get_research_system_text
+from metagpt.const import RESEARCH_PATH
+from metagpt.logs import logger
+from metagpt.roles import Role
+from metagpt.schema import Message
+
+
+class Report(BaseModel):
+ topic: str
+ links: dict[str, list[str]] = None
+ summaries: list[tuple[str, str]] = None
+ content: str = ""
+
+
+class Researcher(Role):
+ def __init__(
+ self,
+ name: str = "David",
+ profile: str = "Researcher",
+ goal: str = "Gather information and conduct research",
+ constraints: str = "Ensure accuracy and relevance of information",
+ language: str = "en-us",
+ **kwargs,
+ ):
+ super().__init__(name, profile, goal, constraints, **kwargs)
+ self._init_actions([CollectLinks(name), WebBrowseAndSummarize(name), ConductResearch(name)])
+ self.language = language
+ if language not in ("en-us", "zh-cn"):
+ logger.warning(f"The language `{language}` has not been tested, it may not work.")
+
+ async def _think(self) -> None:
+ if self._rc.todo is None:
+ self._set_state(0)
+ return
+
+ if self._rc.state + 1 < len(self._states):
+ self._set_state(self._rc.state + 1)
+ else:
+ self._rc.todo = None
+
+ async def _act(self) -> Message:
+ logger.info(f"{self._setting}: ready to {self._rc.todo}")
+ todo = self._rc.todo
+ msg = self._rc.memory.get(k=1)[0]
+ if isinstance(msg.instruct_content, Report):
+ instruct_content = msg.instruct_content
+ topic = instruct_content.topic
+ else:
+ topic = msg.content
+
+ research_system_text = get_research_system_text(topic, self.language)
+ if isinstance(todo, CollectLinks):
+ links = await todo.run(topic, 4, 4)
+ ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo))
+ elif isinstance(todo, WebBrowseAndSummarize):
+ links = instruct_content.links
+ todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
+ summaries = await asyncio.gather(*todos)
+ summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary)
+ ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo))
+ else:
+ summaries = instruct_content.summaries
+ summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
+ content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
+ ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(self._rc.todo))
+ self._rc.memory.add(ret)
+ return ret
+
+ async def _react(self) -> Message:
+ while True:
+ await self._think()
+ if self._rc.todo is None:
+ break
+ msg = await self._act()
+ report = msg.instruct_content
+ self.write_report(report.topic, report.content)
+ return msg
+
+ def write_report(self, topic: str, content: str):
+ filepath = RESEARCH_PATH / f"{topic}.md"
+ filepath.write_text(content)
+
+
+if __name__ == "__main__":
+ role = Researcher(language="en-us")
+ asyncio.run(role.run("dataiku vs. datarobot"))
diff --git a/metagpt/schema.py b/metagpt/schema.py
index 64db39d0d..27f5dd10c 100644
--- a/metagpt/schema.py
+++ b/metagpt/schema.py
@@ -46,21 +46,27 @@ class Message:
@dataclass
class UserMessage(Message):
- """便于支持OpenAI的消息"""
+ """便于支持OpenAI的消息
+ Facilitate support for OpenAI messages
+ """
def __init__(self, content: str):
super().__init__(content, 'user')
@dataclass
class SystemMessage(Message):
- """便于支持OpenAI的消息"""
+ """便于支持OpenAI的消息
+ Facilitate support for OpenAI messages
+ """
def __init__(self, content: str):
super().__init__(content, 'system')
@dataclass
class AIMessage(Message):
- """便于支持OpenAI的消息"""
+ """便于支持OpenAI的消息
+ Facilitate support for OpenAI messages
+ """
def __init__(self, content: str):
super().__init__(content, 'assistant')
diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py
index f9b7abc52..e1f921c05 100644
--- a/metagpt/tools/__init__.py
+++ b/metagpt/tools/__init__.py
@@ -14,6 +14,7 @@ class SearchEngineType(Enum):
SERPAPI_GOOGLE = auto()
DIRECT_GOOGLE = auto()
SERPER_GOOGLE = auto()
+ DUCK_DUCK_GO = auto()
CUSTOM_ENGINE = auto()
diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py
index e462f1bda..a63dbe5ac 100644
--- a/metagpt/tools/sd_engine.py
+++ b/metagpt/tools/sd_engine.py
@@ -2,29 +2,27 @@
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
-import os
import asyncio
+import base64
+import io
+import json
+import os
from os.path import join
from typing import List
-import json
-import io
-import base64
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
-from metagpt.logs import logger
from metagpt.config import Config
from metagpt.const import WORKSPACE_ROOT
+from metagpt.logs import logger
config = Config()
payload = {
"prompt": "",
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
- "override_settings": {
- "sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"
- },
+ "override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
"seed": -1,
"batch_size": 1,
"n_iter": 1,
@@ -36,21 +34,20 @@ payload = {
"tiling": False,
"do_not_save_samples": False,
"do_not_save_grid": False,
- 'enable_hr': False,
- 'hr_scale': 2,
- 'hr_upscaler': 'Latent',
- 'hr_second_pass_steps': 0,
- 'hr_resize_x': 0,
- 'hr_resize_y': 0,
- 'hr_upscale_to_x': 0,
- 'hr_upscale_to_y': 0,
- 'truncate_x': 0,
- 'truncate_y': 0,
- 'applied_old_hires_behavior_to': None,
+ "enable_hr": False,
+ "hr_scale": 2,
+ "hr_upscaler": "Latent",
+ "hr_second_pass_steps": 0,
+ "hr_resize_x": 0,
+ "hr_resize_y": 0,
+ "hr_upscale_to_x": 0,
+ "hr_upscale_to_y": 0,
+ "truncate_x": 0,
+ "truncate_y": 0,
+ "applied_old_hires_behavior_to": None,
"eta": None,
-
"sampler_index": "DPM++ SDE Karras",
- "alwayson_scripts": {}
+ "alwayson_scripts": {},
}
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
@@ -60,14 +57,20 @@ class SDEngine:
def __init__(self):
# Initialize the SDEngine with configuration
self.config = Config()
- self.sd_url = self.config.get('SD_URL')
+ self.sd_url = self.config.get("SD_URL")
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
-
- def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
- sd_model="galaxytimemachinesGTM_photoV20"):
+
+ def construct_payload(
+ self,
+ prompt,
+ negtive_prompt=default_negative_prompt,
+ width=512,
+ height=512,
+ sd_model="galaxytimemachinesGTM_photoV20",
+ ):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
self.payload["negtive_prompt"] = negtive_prompt
@@ -76,13 +79,13 @@ class SDEngine:
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
logger.info(f"call sd payload is {self.payload}")
return self.payload
-
+
def _save(self, imgs, save_name=""):
- save_dir = WORKSPACE_ROOT / "resources"/"SD_Output"
+ save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
-
+
async def run_t2i(self, prompts: List):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
@@ -90,25 +93,26 @@ class SDEngine:
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self._save(results, save_name=f"output_{payload_idx}")
await session.close()
-
+
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
-
+
rsp_json = json.loads(data)
- imgs = rsp_json['images']
+ imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
-
+
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
-
+
async def run_sam(self):
# todo:添加SAM接口调用
raise NotImplementedError
+
def decode_base64_to_image(img, save_name):
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
pnginfo = PngImagePlugin.PngInfo()
@@ -124,12 +128,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
if __name__ == "__main__":
- import asyncio
-
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
-
+
engine.construct_payload(prompt)
-
+
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(engine.run_t2i(prompt))
diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py
index cfd4e8789..d28700054 100644
--- a/metagpt/tools/search_engine.py
+++ b/metagpt/tools/search_engine.py
@@ -7,122 +7,76 @@
"""
from __future__ import annotations
-import json
+import importlib
+from typing import Callable, Coroutine, Literal, overload
-from metagpt.config import Config
-from metagpt.logs import logger
-from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
-from metagpt.tools.search_engine_serper import SerperWrapper
-
-config = Config()
+from metagpt.config import CONFIG
from metagpt.tools import SearchEngineType
class SearchEngine:
- """
- TODO: 合入Google Search 并进行反代
- 注:这里Google需要挂Proxifier或者类似全局代理
- - DDG: https://pypi.org/project/duckduckgo-search/
- - GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
- """
- def __init__(self, engine=None, run_func=None):
- self.config = Config()
- self.run_func = run_func
- self.engine = engine or self.config.search_engine
+ """Class representing a search engine.
- @classmethod
- def run_google(cls, query, max_results=8):
- # results = ddg(query, max_results=max_results)
- results = google_official_search(query, num_results=max_results)
- logger.info(results)
- return results
+ Args:
+ engine: The search engine type. Defaults to the search engine specified in the config.
+ run_func: The function to run the search. Defaults to None.
- async def run(self, query: str, max_results=8):
- if self.engine == SearchEngineType.SERPAPI_GOOGLE:
- api = SerpAPIWrapper()
- rsp = await api.run(query)
- elif self.engine == SearchEngineType.DIRECT_GOOGLE:
- rsp = SearchEngine.run_google(query, max_results)
- elif self.engine == SearchEngineType.SERPER_GOOGLE:
- api = SerperWrapper()
- rsp = await api.run(query)
- elif self.engine == SearchEngineType.CUSTOM_ENGINE:
- rsp = self.run_func(query)
+ Attributes:
+ run_func: The function to run the search.
+ engine: The search engine type.
+ """
+ def __init__(
+ self,
+ engine: SearchEngineType | None = None,
+ run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
+ ):
+ engine = engine or CONFIG.search_engine
+ if engine == SearchEngineType.SERPAPI_GOOGLE:
+ module = "metagpt.tools.search_engine_serpapi"
+ run_func = importlib.import_module(module).SerpAPIWrapper().run
+ elif engine == SearchEngineType.SERPER_GOOGLE:
+ module = "metagpt.tools.search_engine_serper"
+ run_func = importlib.import_module(module).SerperWrapper().run
+ elif engine == SearchEngineType.DIRECT_GOOGLE:
+ module = "metagpt.tools.search_engine_googleapi"
+ run_func = importlib.import_module(module).GoogleAPIWrapper().run
+ elif engine == SearchEngineType.DUCK_DUCK_GO:
+ module = "metagpt.tools.search_engine_ddg"
+ run_func = importlib.import_module(module).DDGAPIWrapper().run
+ elif engine == SearchEngineType.CUSTOM_ENGINE:
+ pass # run_func = run_func
else:
raise NotImplementedError
- return rsp
+ self.engine = engine
+ self.run_func = run_func
+ @overload
+ def run(
+ self,
+ query: str,
+ max_results: int = 8,
+ as_string: Literal[True] = True,
+ ) -> str:
+ ...
-def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
- """Return the results of a Google search using the official Google API
+ @overload
+ def run(
+ self,
+ query: str,
+ max_results: int = 8,
+ as_string: Literal[False] = False,
+ ) -> list[dict[str, str]]:
+ ...
- Args:
- query (str): The search query.
- num_results (int): The number of results to return.
+ async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
+ """Run a search query.
- Returns:
- str: The results of the search.
- """
+ Args:
+ query: The search query.
+ max_results: The maximum number of results to return. Defaults to 8.
+ as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
- from googleapiclient.discovery import build
- from googleapiclient.errors import HttpError
-
- try:
- api_key = config.google_api_key
- custom_search_engine_id = config.google_cse_id
-
- with build("customsearch", "v1", developerKey=api_key) as service:
-
- result = (
- service.cse()
- .list(q=query, cx=custom_search_engine_id, num=num_results)
- .execute()
- )
- logger.info(result)
- # Extract the search result items from the response
- search_results = result.get("items", [])
-
- # Create a list of only the URLs from the search results
- search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
-
- except HttpError as e:
- # Handle errors in the API call
- error_details = json.loads(e.content.decode())
-
- # Check if the error is related to an invalid or missing API key
- if error_details.get("error", {}).get(
- "code"
- ) == 403 and "invalid API key" in error_details.get("error", {}).get(
- "message", ""
- ):
- return "Error: The provided Google API key is invalid or missing."
- else:
- return f"Error: {e}"
- # google_result can be a list or a string depending on the search results
-
- # Return the list of search result URLs
- return search_results_details
-
-
-def safe_google_results(results: str | list) -> str:
- """
- Return the results of a google search in a safe format.
-
- Args:
- results (str | list): The search results.
-
- Returns:
- str: The results of the search.
- """
- if isinstance(results, list):
- safe_message = json.dumps(
- # FIXME: # .encode("utf-8", "ignore") 这里去掉了,但是AutoGPT里有,很奇怪
- [result for result in results]
- )
- else:
- safe_message = results.encode("utf-8", "ignore").decode("utf-8")
- return safe_message
-
-
-if __name__ == '__main__':
- SearchEngine.run(query='wtf')
+ Returns:
+ The search results as a string or a list of dictionaries.
+ """
+ return await self.run_func(query, max_results=max_results, as_string=as_string)
diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py
new file mode 100644
index 000000000..c054afed1
--- /dev/null
+++ b/metagpt/tools/search_engine_ddg.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import asyncio
+import json
+from concurrent import futures
+from typing import Literal, overload
+
+from duckduckgo_search import DDGS
+from googleapiclient.errors import HttpError
+
+from metagpt.config import CONFIG
+from metagpt.logs import logger
+
+
+class DDGAPIWrapper:
+ """Wrapper around duckduckgo_search API.
+
+ To use this module, you should have the `duckduckgo_search` Python package installed.
+ """
+ def __init__(
+ self,
+ *,
+ loop: asyncio.AbstractEventLoop | None = None,
+ executor: futures.Executor | None = None,
+ ):
+ kwargs = {}
+ if CONFIG.global_proxy:
+ kwargs["proxies"] = CONFIG.global_proxy
+ self.loop = loop
+ self.executor = executor
+ self.ddgs = DDGS(**kwargs)
+
+ @overload
+ def run(
+ self,
+ query: str,
+ max_results: int = 8,
+ as_string: Literal[True] = True,
+ focus: list[str] | None = None,
+ ) -> str:
+ ...
+
+ @overload
+ def run(
+ self,
+ query: str,
+ max_results: int = 8,
+ as_string: Literal[False] = False,
+ focus: list[str] | None = None,
+ ) -> list[dict[str, str]]:
+ ...
+
+ async def run(
+ self,
+ query: str,
+ max_results: int = 8,
+ as_string: bool = True,
+ ) -> str | list[dict]:
+ """Return the results of a Google search using the official Google API
+
+ Args:
+ query: The search query.
+ max_results: The number of results to return.
+ as_string: A boolean flag to determine the return type of the results. If True, the function will
+ return a formatted string with the search results. If False, it will return a list of dictionaries
+ containing detailed information about each search result.
+
+ Returns:
+ The results of the search.
+ """
+ loop = self.loop or asyncio.get_event_loop()
+ future = loop.run_in_executor(
+ self.executor,
+ self._search_from_ddgs,
+ query,
+ max_results,
+ )
+ try:
+ search_results = await future
+ # Extract the search result items from the response
+
+ except HttpError as e:
+ # Handle errors in the API call
+ logger.exception(f"fail to search {query} for {e}")
+ search_results = []
+
+ # Return the list of search result URLs
+ if as_string:
+ return json.dumps(search_results, ensure_ascii=False)
+ return search_results
+
+ def _search_from_ddgs(self, query: str, max_results: int):
+ return [
+ {
+ "link": i["href"],
+ "snippet": i["body"],
+ "title": i["title"]
+ } for (_, i) in zip(range(max_results), self.ddgs.text(query))
+ ]
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(DDGAPIWrapper().run)
diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py
new file mode 100644
index 000000000..c226ca8d2
--- /dev/null
+++ b/metagpt/tools/search_engine_googleapi.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import asyncio
+import json
+from concurrent import futures
+from urllib.parse import urlparse
+
+import httplib2
+from googleapiclient.discovery import build
+from googleapiclient.errors import HttpError
+
+from metagpt.config import CONFIG
+from metagpt.logs import logger
+
+
+class GoogleAPIWrapper:
+ """Wrapper around GoogleAPI.
+
+ To use this module, you should have the `google-api-python-client` Python package installed
+ and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
+ https://programmablesearchengine.google.com/controlpanel/all.
+ """
+ def __init__(
+ self,
+ *,
+ loop: asyncio.AbstractEventLoop | None = None,
+ executor: futures.Executor | None = None,
+ ):
+ build_kwargs = {"developerKey": CONFIG.google_api_key}
+ if CONFIG.global_proxy:
+ parse_result = urlparse(CONFIG.global_proxy)
+ proxy_type = parse_result.scheme
+ if proxy_type == "https":
+ proxy_type = "http"
+ build_kwargs["http"] = httplib2.Http(
+ proxy_info=httplib2.ProxyInfo(
+ getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"),
+ parse_result.hostname,
+ parse_result.port,
+ ),
+ )
+ service = build("customsearch", "v1", **build_kwargs)
+ self.google_api_client = service.cse()
+ self.custom_search_engine_id = CONFIG.google_cse_id
+ self.loop = loop
+ self.executor = executor
+
+ async def run(
+ self,
+ query: str,
+ max_results: int = 8,
+ as_string: bool = True,
+ focus: list[str] | None = None,
+ ) -> str | list[dict]:
+ """Return the results of a Google search using the official Google API.
+
+ Args:
+ query: The search query.
+ max_results: The number of results to return.
+ as_string: A boolean flag to determine the return type of the results. If True, the function will
+ return a formatted string with the search results. If False, it will return a list of dictionaries
+ containing detailed information about each search result.
+ focus: Specific information to be focused on from each search result.
+
+ Returns:
+ The results of the search.
+ """
+ loop = self.loop or asyncio.get_event_loop()
+ future = loop.run_in_executor(
+ self.executor,
+ self.google_api_client.list(
+ q=query,
+ num=max_results,
+ cx=self.custom_search_engine_id
+ ).execute
+ )
+ try:
+ result = await future
+ # Extract the search result items from the response
+ search_results = result.get("items", [])
+
+ except HttpError as e:
+ # Handle errors in the API call
+ logger.exception(f"fail to search {query} for {e}")
+ search_results = []
+
+ focus = focus or ["snippet", "link", "title"]
+ details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
+ # Return the list of search result URLs
+ if as_string:
+ return safe_google_results(details)
+
+ return details
+
+
+def safe_google_results(results: str | list) -> str:
+ """Return the results of a google search in a safe format.
+
+ Args:
+ results: The search results.
+
+ Returns:
+ The results of the search.
+ """
+ if isinstance(results, list):
+ safe_message = json.dumps([result for result in results])
+ else:
+ safe_message = results.encode("utf-8", "ignore").decode("utf-8")
+ return safe_message
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(GoogleAPIWrapper().run)
diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py
index 28033f237..3d2d7cfe4 100644
--- a/metagpt/tools/search_engine_serpapi.py
+++ b/metagpt/tools/search_engine_serpapi.py
@@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
- async def run(self, query: str, **kwargs: Any) -> str:
+ async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
- return self._process_response(await self.results(query))
+ return self._process_response(await self.results(query, max_results), as_string=as_string)
- async def results(self, query: str) -> dict:
+ async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
+ params["num"] = max_results
if self.serpapi_api_key:
params["serp_api_key"] = self.serpapi_api_key
params["output"] = "json"
@@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
return params
@staticmethod
- def _process_response(res: dict) -> str:
+ def _process_response(res: dict, as_string: bool) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
- focus = ['title', 'snippet', 'link']
+ focus = ["title", "snippet", "link"]
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
@@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
- elif (
- "answer_box" in res.keys()
- and "snippet_highlighted_words" in res["answer_box"].keys()
- ):
+ elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
toret = res["answer_box"]["snippet_highlighted_words"][0]
- elif (
- "sports_results" in res.keys()
- and "game_spotlight" in res["sports_results"].keys()
- ):
+ elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
toret = res["sports_results"]["game_spotlight"]
- elif (
- "knowledge_graph" in res.keys()
- and "description" in res["knowledge_graph"].keys()
- ):
+ elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["snippet"]
@@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel):
if res.get("organic_results"):
toret_l += [get_focused(i) for i in res.get("organic_results")]
- return str(toret) + '\n' + str(toret_l)
+ return str(toret) + '\n' + str(toret_l) if as_string else toret_l
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(SerpAPIWrapper().run)
diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py
index 80c2f8001..2ae2c3b7d 100644
--- a/metagpt/tools/search_engine_serper.py
+++ b/metagpt/tools/search_engine_serper.py
@@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
- async def run(self, query: str, **kwargs: Any) -> str:
+ async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
- queries = query.split("\n")
- return "\n".join([self._process_response(res) for res in await self.results(queries)])
+ if isinstance(query, str):
+ return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
+ else:
+ results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
+ return "\n".join(results) if as_string else results
- async def results(self, queries: list[str]) -> dict:
+ async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
- payloads = self.get_payloads(queries)
+ payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
@@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
return res
- def get_payloads(self, queries: list[str]) -> Dict[str, str]:
+ def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
"""Get payloads for Serper."""
payloads = []
for query in queries:
_payload = {
"q": query,
+ "num": max_results,
}
payloads.append({**self.payload, **_payload})
return json.dumps(payloads, sort_keys=True)
@@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
return headers
@staticmethod
- def _process_response(res: dict) -> str:
+ def _process_response(res: dict, as_string: bool = False) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
@@ -117,4 +121,10 @@ class SerperWrapper(BaseModel):
if res.get("organic"):
toret_l += [get_focused(i) for i in res.get("organic")]
- return str(toret) + '\n' + str(toret_l)
+ return str(toret) + '\n' + str(toret_l) if as_string else toret_l
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(SerperWrapper().run)
diff --git a/metagpt/tools/web_browser_engine.py b/metagpt/tools/web_browser_engine.py
index d1f83934f..453d87f31 100644
--- a/metagpt/tools/web_browser_engine.py
+++ b/metagpt/tools/web_browser_engine.py
@@ -1,22 +1,20 @@
#!/usr/bin/env python
from __future__ import annotations
-import asyncio
-import importlib
-from typing import Any, Callable, Coroutine, overload
+import importlib
+from typing import Any, Callable, Coroutine, Literal, overload
from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
-from bs4 import BeautifulSoup
+from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType | None = None,
- run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None,
- parse_func: Callable[[str], str] | None = None,
+ run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
engine = engine or CONFIG.web_browser_engine
@@ -30,30 +28,25 @@ class WebBrowserEngine:
run_func = run_func
else:
raise NotImplementedError
- self.parse_func = parse_func or get_page_content
self.run_func = run_func
self.engine = engine
@overload
- async def run(self, url: str) -> str:
+ async def run(self, url: str) -> WebPage:
...
@overload
- async def run(self, url: str, *urls: str) -> list[str]:
+ async def run(self, url: str, *urls: str) -> list[WebPage]:
...
- async def run(self, url: str, *urls: str) -> str | list[str]:
- page = await self.run_func(url, *urls)
- if isinstance(page, str):
- return self.parse_func(page)
- return [self.parse_func(i) for i in page]
-
-
-def get_page_content(page: str):
- soup = BeautifulSoup(page, "html.parser")
- return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"]))
+ async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
+ return await self.run_func(url, *urls)
if __name__ == "__main__":
- text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/"))
- print(text)
+ import fire
+
+ async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
+ return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
+
+ fire.Fire(main)
diff --git a/metagpt/tools/web_browser_engine_playwright.py b/metagpt/tools/web_browser_engine_playwright.py
index ae8644cce..030e7701b 100644
--- a/metagpt/tools/web_browser_engine_playwright.py
+++ b/metagpt/tools/web_browser_engine_playwright.py
@@ -2,12 +2,15 @@
from __future__ import annotations
import asyncio
-from pathlib import Path
import sys
+from pathlib import Path
from typing import Literal
+
from playwright.async_api import async_playwright
+
from metagpt.config import CONFIG
from metagpt.logs import logger
+from metagpt.utils.parse_html import WebPage
class PlaywrightWrapper:
@@ -16,7 +19,7 @@ class PlaywrightWrapper:
To use this module, you should have the `playwright` Python package installed and ensure that
the required browsers are also installed. You can install playwright by running the command
`pip install metagpt[playwright]` and download the necessary browser binaries by running the
- command `playwright install` for the first time."
+ command `playwright install` for the first time.
"""
def __init__(
@@ -40,27 +43,30 @@ class PlaywrightWrapper:
self._context_kwargs = context_kwargs
self._has_run_precheck = False
- async def run(self, url: str, *urls: str) -> str | list[str]:
+ async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
async with async_playwright() as ap:
browser_type = getattr(ap, self.browser_type)
await self._run_precheck(browser_type)
browser = await browser_type.launch(**self.launch_kwargs)
-
- async def _scrape(url):
- context = await browser.new_context(**self._context_kwargs)
- page = await context.new_page()
- async with page:
- try:
- await page.goto(url)
- await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
- content = await page.content()
- return content
- except Exception as e:
- return f"Fail to load page content for {e}"
+ _scrape = self._scrape
if urls:
- return await asyncio.gather(_scrape(url), *(_scrape(i) for i in urls))
- return await _scrape(url)
+ return await asyncio.gather(_scrape(browser, url), *(_scrape(browser, i) for i in urls))
+ return await _scrape(browser, url)
+
+ async def _scrape(self, browser, url):
+ context = await browser.new_context(**self._context_kwargs)
+ page = await context.new_page()
+ async with page:
+ try:
+ await page.goto(url)
+ await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
+ html = await page.content()
+ inner_text = await page.evaluate("() => document.body.innerText")
+ except Exception as e:
+ inner_text = f"Fail to load page content for {e}"
+ html = ""
+ return WebPage(inner_text=inner_text, html=html, url=url)
async def _run_precheck(self, browser_type):
if self._has_run_precheck:
@@ -72,6 +78,10 @@ class PlaywrightWrapper:
if CONFIG.global_proxy:
kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy}
await _install_browsers(self.browser_type, **kwargs)
+
+ if self._has_run_precheck:
+ return
+
if not executable_path.exists():
parts = executable_path.parts
available_paths = list(Path(*parts[:-3]).glob(f"{self.browser_type}-*"))
@@ -85,25 +95,37 @@ class PlaywrightWrapper:
self._has_run_precheck = True
+def _get_install_lock():
+ global _install_lock
+ if _install_lock is None:
+ _install_lock = asyncio.Lock()
+ return _install_lock
+
+
async def _install_browsers(*browsers, **kwargs) -> None:
- process = await asyncio.create_subprocess_exec(
- sys.executable,
- "-m",
- "playwright",
- "install",
- *browsers,
- "--with-deps",
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE,
- **kwargs,
- )
+ async with _get_install_lock():
+ browsers = [i for i in browsers if i not in _install_cache]
+ if not browsers:
+ return
+ process = await asyncio.create_subprocess_exec(
+ sys.executable,
+ "-m",
+ "playwright",
+ "install",
+ *browsers,
+ # "--with-deps",
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ **kwargs,
+ )
- await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning))
+ await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning))
- if await process.wait() == 0:
- logger.info(f"Install browser for playwright successfully.")
- else:
- logger.warning(f"Fail to install browser for playwright.")
+ if await process.wait() == 0:
+ logger.info("Install browser for playwright successfully.")
+ else:
+ logger.warning("Fail to install browser for playwright.")
+ _install_cache.update(browsers)
async def _log_stream(sr, log_func):
@@ -114,8 +136,14 @@ async def _log_stream(sr, log_func):
log_func(f"[playwright install browser]: {line.decode().strip()}")
+_install_lock: asyncio.Lock = None
+_install_cache = set()
+
+
if __name__ == "__main__":
- for i in ("chromium", "firefox", "webkit"):
- text = asyncio.run(PlaywrightWrapper(i).run("https://httpbin.org/ip"))
- print(text)
- print(i)
+ import fire
+
+ async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
+ return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls)
+
+ fire.Fire(main)
diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py
index bd8a456ea..d727709b8 100644
--- a/metagpt/tools/web_browser_engine_selenium.py
+++ b/metagpt/tools/web_browser_engine_selenium.py
@@ -2,16 +2,17 @@
from __future__ import annotations
import asyncio
-from copy import deepcopy
import importlib
+from concurrent import futures
+from copy import deepcopy
from typing import Literal
-from metagpt.config import CONFIG
-import asyncio
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
-from concurrent import futures
+
+from metagpt.config import CONFIG
+from metagpt.utils.parse_html import WebPage
class SeleniumWrapper:
@@ -48,7 +49,7 @@ class SeleniumWrapper:
self.loop = loop
self.executor = executor
- async def run(self, url: str, *urls: str) -> str | list[str]:
+ async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
await self._run_precheck()
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
@@ -69,9 +70,15 @@ class SeleniumWrapper:
def _scrape_website(self, url):
with self._get_driver() as driver:
- driver.get(url)
- WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
- return driver.page_source
+ try:
+ driver.get(url)
+ WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
+ inner_text = driver.execute_script("return document.body.innerText;")
+ html = driver.page_source
+ except Exception as e:
+ inner_text = f"Fail to load page content for {e}"
+ html = ""
+ return WebPage(inner_text=inner_text, html=html, url=url)
_webdriver_manager_types = {
@@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
def _get_driver():
options = Options()
options.add_argument("--headless")
+ options.add_argument("--enable-javascript")
if browser_type == "chrome":
options.add_argument("--no-sandbox")
for i in args:
@@ -107,5 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
if __name__ == "__main__":
- text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/"))
- print(text)
+ import fire
+
+ async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
+ return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
+
+ fire.Fire(main)
diff --git a/metagpt/utils/__init__.py b/metagpt/utils/__init__.py
index 579308a3b..f13175cf8 100644
--- a/metagpt/utils/__init__.py
+++ b/metagpt/utils/__init__.py
@@ -13,3 +13,12 @@ from metagpt.utils.token_counter import (
count_message_tokens,
count_string_tokens,
)
+
+
+__all__ = [
+ "read_docx",
+ "Singleton",
+ "TOKEN_COSTS",
+ "count_message_tokens",
+ "count_string_tokens",
+]
diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py
index 3788b4743..24aabe8ae 100644
--- a/metagpt/utils/mermaid.py
+++ b/metagpt/utils/mermaid.py
@@ -5,9 +5,9 @@
@Author : alexanderwu
@File : mermaid.py
"""
-import os
import subprocess
from pathlib import Path
+
from metagpt.config import CONFIG
from metagpt.const import PROJECT_ROOT
from metagpt.logs import logger
@@ -24,25 +24,36 @@ def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height
:return: 0 if succed, -1 if failed
"""
# Write the Mermaid code to a temporary file
- tmp = Path(f'{output_file_without_suffix}.mmd')
- tmp.write_text(mermaid_code, encoding='utf-8')
+ tmp = Path(f"{output_file_without_suffix}.mmd")
+ tmp.write_text(mermaid_code, encoding="utf-8")
- if check_cmd_exists('mmdc') != 0:
- logger.warning(
- "RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
+ if check_cmd_exists("mmdc") != 0:
+ logger.warning("RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
return -1
- for suffix in ['pdf', 'svg', 'png']:
- output_file = f'{output_file_without_suffix}.{suffix}'
+ for suffix in ["pdf", "svg", "png"]:
+ output_file = f"{output_file_without_suffix}.{suffix}"
# Call the `mmdc` command to convert the Mermaid code to a PNG
logger.info(f"Generating {output_file}..")
if CONFIG.puppeteer_config:
- subprocess.run([CONFIG.mmdc, '-p', CONFIG.puppeteer_config, '-i', str(tmp), '-o',
- output_file, '-w', str(width), '-H', str(height)])
+ subprocess.run(
+ [
+ CONFIG.mmdc,
+ "-p",
+ CONFIG.puppeteer_config,
+ "-i",
+ str(tmp),
+ "-o",
+ output_file,
+ "-w",
+ str(width),
+ "-H",
+ str(height),
+ ]
+ )
else:
- subprocess.run([CONFIG.mmdc, '-i', str(tmp), '-o',
- output_file, '-w', str(width), '-H', str(height)])
+ subprocess.run([CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)])
return 0
@@ -97,7 +108,7 @@ MMC2 = """sequenceDiagram
SE-->>M: return summary"""
-if __name__ == '__main__':
+if __name__ == "__main__":
# logger.info(print_members(print_members))
- mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png')
- mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png')
+ mermaid_to_file(MMC1, PROJECT_ROOT / "tmp/1.png")
+ mermaid_to_file(MMC2, PROJECT_ROOT / "tmp/2.png")
diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py
new file mode 100644
index 000000000..62de26541
--- /dev/null
+++ b/metagpt/utils/parse_html.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+from typing import Generator, Optional
+from urllib.parse import urljoin, urlparse
+
+from bs4 import BeautifulSoup
+from pydantic import BaseModel
+
+
+class WebPage(BaseModel):
+ inner_text: str
+ html: str
+ url: str
+
+ class Config:
+ underscore_attrs_are_private = True
+
+ _soup : Optional[BeautifulSoup] = None
+ _title: Optional[str] = None
+
+ @property
+ def soup(self) -> BeautifulSoup:
+ if self._soup is None:
+ self._soup = BeautifulSoup(self.html, "html.parser")
+ return self._soup
+
+ @property
+ def title(self):
+ if self._title is None:
+ title_tag = self.soup.find("title")
+ self._title = title_tag.text.strip() if title_tag is not None else ""
+ return self._title
+
+ def get_links(self) -> Generator[str, None, None]:
+ for i in self.soup.find_all("a", href=True):
+ url = i["href"]
+ result = urlparse(url)
+ if not result.scheme and result.path:
+ yield urljoin(self.url, url)
+ elif url.startswith(("http://", "https://")):
+ yield urljoin(self.url, url)
+
+
+def get_html_content(page: str, base: str):
+ soup = _get_soup(page)
+
+ return soup.get_text(strip=True)
+
+
+def _get_soup(page: str):
+ soup = BeautifulSoup(page, "html.parser")
+ # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
+ for s in soup(["style", "script", "[document]", "head", "title"]):
+ s.extract()
+
+ return soup
diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py
index 34dee7098..ffafca8cd 100644
--- a/metagpt/utils/serialize.py
+++ b/metagpt/utils/serialize.py
@@ -3,14 +3,11 @@
# @Desc : the implement of serialization and deserialization
import copy
-from typing import Tuple, List, Type, Union, Dict
import pickle
-from collections import defaultdict
-from pydantic import create_model
+from typing import Dict, List, Tuple
-from metagpt.schema import Message
-from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
+from metagpt.schema import Message
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
@@ -34,12 +31,12 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
```
"""
mapping = dict()
- for field, property in schema['properties'].items():
- if property['type'] == 'string':
+ for field, property in schema["properties"].items():
+ if property["type"] == "string":
mapping[field] = (str, ...)
- elif property['type'] == 'array' and property['items']['type'] == 'string':
+ elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
- elif property['type'] == 'array' and property['items']['type'] == 'array':
+ elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `Tuple[str, str]` situation
mapping[field] = (List[Tuple[str, str]], ...)
return mapping
@@ -53,11 +50,7 @@ def serialize_message(message: Message):
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
- message_cp.instruct_content = {
- 'class': schema['title'],
- 'mapping': mapping,
- 'value': ic.dict()
- }
+ message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
msg_ser = pickle.dumps(message_cp)
return msg_ser
@@ -67,9 +60,8 @@ def deserialize_message(message_ser: str) -> Message:
message = pickle.loads(message_ser)
if message.instruct_content:
ic = message.instruct_content
- ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
- mapping=ic['mapping'])
- ic_new = ic_obj(**ic['value'])
+ ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
+ ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new
return message
diff --git a/metagpt/utils/text.py b/metagpt/utils/text.py
new file mode 100644
index 000000000..be3c52edd
--- /dev/null
+++ b/metagpt/utils/text.py
@@ -0,0 +1,124 @@
+from typing import Generator, Sequence
+
+from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
+
+
+def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
+ """Reduce the length of concatenated message segments to fit within the maximum token size.
+
+ Args:
+ msgs: A generator of strings representing progressively shorter valid prompts.
+ model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
+ system_text: The system prompts.
+ reserved: The number of reserved tokens.
+
+ Returns:
+ The concatenated message segments reduced to fit within the maximum token size.
+
+ Raises:
+ RuntimeError: If it fails to reduce the concatenated message length.
+ """
+ max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
+ for msg in msgs:
+ if count_string_tokens(msg, model_name) < max_token:
+ return msg
+
+ raise RuntimeError("fail to reduce message length")
+
+
+def generate_prompt_chunk(
+ text: str,
+ prompt_template: str,
+ model_name: str,
+ system_text: str,
+ reserved: int = 0,
+) -> Generator[str, None, None]:
+ """Split the text into chunks of a maximum token size.
+
+ Args:
+ text: The text to split.
+ prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}".
+ model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
+ system_text: The system prompts.
+ reserved: The number of reserved tokens.
+
+ Yields:
+ The chunk of text.
+ """
+ paragraphs = text.splitlines(keepends=True)
+ current_token = 0
+ current_lines = []
+
+ reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
+ # 100 is a magic number to ensure the maximum context length is not exceeded
+ max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
+
+ while paragraphs:
+ paragraph = paragraphs.pop(0)
+ token = count_string_tokens(paragraph, model_name)
+ if current_token + token <= max_token:
+ current_lines.append(paragraph)
+ current_token += token
+ elif token > max_token:
+ paragraphs = split_paragraph(paragraph) + paragraphs
+ continue
+ else:
+ yield prompt_template.format("".join(current_lines))
+ current_lines = [paragraph]
+ current_token = token
+
+ if current_lines:
+ yield prompt_template.format("".join(current_lines))
+
+
+def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
+ """Split a paragraph into multiple parts.
+
+ Args:
+ paragraph: The paragraph to split.
+ sep: The separator character.
+ count: The number of parts to split the paragraph into.
+
+ Returns:
+ A list of split parts of the paragraph.
+ """
+ for i in sep:
+ sentences = list(_split_text_with_ends(paragraph, i))
+ if len(sentences) <= 1:
+ continue
+ ret = ["".join(j) for j in _split_by_count(sentences, count)]
+ return ret
+ return _split_by_count(paragraph, count)
+
+
+def decode_unicode_escape(text: str) -> str:
+ """Decode a text with unicode escape sequences.
+
+ Args:
+ text: The text to decode.
+
+ Returns:
+ The decoded text.
+ """
+ return text.encode("utf-8").decode("unicode_escape", "ignore")
+
+
+def _split_by_count(lst: Sequence , count: int):
+ avg = len(lst) // count
+ remainder = len(lst) % count
+ start = 0
+ for i in range(count):
+ end = start + avg + (1 if i < remainder else 0)
+ yield lst[start:end]
+ start = end
+
+
+def _split_text_with_ends(text: str, sep: str = "."):
+ parts = []
+ for i in text:
+ parts.append(i)
+ if i == sep:
+ yield "".join(parts)
+ parts = []
+ if parts:
+ yield "".join(parts)
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index 99ae5e176..591bb60f0 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -25,6 +25,21 @@ TOKEN_COSTS = {
}
+TOKEN_MAX = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0301": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-3.5-turbo-16k-0613": 16384,
+ "gpt-4-0314": 8192,
+ "gpt-4": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4-32k-0314": 32768,
+ "gpt-4-0613": 8192,
+ "text-embedding-ada-002": 8192,
+}
+
+
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
@@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
- }:
+ }:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
@@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int:
"""
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(string))
+
+
+def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
+ """Calculate the maximum number of completion tokens for a given model and list of messages.
+
+ Args:
+ messages: A list of messages.
+ model: The model name.
+
+ Returns:
+ The maximum number of completion tokens.
+ """
+ if model not in TOKEN_MAX:
+ return default
+ return TOKEN_MAX[model] - count_message_tokens(messages)
diff --git a/requirements.txt b/requirements.txt
index d29a0c975..72021b8e7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,7 +17,7 @@ numpy==1.24.3
openai==0.27.8
openpyxl
pandas==1.4.1
-pydantic==1.10.7
+pydantic==1.10.8
#pygame==2.1.3
#pymilvus==2.2.8
pytest==7.2.2
@@ -37,4 +37,5 @@ typing-inspect==0.8.0
typing_extensions==4.5.0
aiofiles
libcst==1.0.1
+qdrant-client==1.4.0
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 000000000..7835865e0
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,40 @@
+select = ["E", "F"]
+ignore = ["E501", "E712", "E722", "F821", "E731"]
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
+unfixable = []
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "venv",
+]
+
+# Same as Black.
+line-length = 119
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.9
+target-version = "py39"
\ No newline at end of file
diff --git a/setup.py b/setup.py
index e65696901..2a8edaae7 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@ setup(
install_requires=requirements,
extras_require={
"playwright": ["playwright>=1.26", "beautifulsoup4"],
- "selenium": ["selenium>4", "webdriver_manager<3.9", "beautifulsoup4"],
+ "selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
},
cmdclass={
"install_mermaid": InstallMermaidCLI,
diff --git a/tests/metagpt/actions/test_run_code.py b/tests/metagpt/actions/test_run_code.py
index 489da28c6..1e451cb14 100644
--- a/tests/metagpt/actions/test_run_code.py
+++ b/tests/metagpt/actions/test_run_code.py
@@ -6,24 +6,23 @@
@File : test_run_code.py
"""
import pytest
-import asyncio
+
from metagpt.actions.run_code import RunCode
+
@pytest.mark.asyncio
async def test_run_text():
- action = RunCode()
- result, errs = await RunCode.run_text('result = 1 + 1')
+ result, errs = await RunCode.run_text("result = 1 + 1")
assert result == 2
assert errs == ""
- result, errs = await RunCode.run_text('result = 1 / 0')
+ result, errs = await RunCode.run_text("result = 1 / 0")
assert result == ""
assert "ZeroDivisionError" in errs
+
@pytest.mark.asyncio
async def test_run_script():
- action = RunCode()
-
# Successful command
out, err = await RunCode.run_script(".", command=["echo", "Hello World"])
assert out.strip() == "Hello World"
@@ -33,6 +32,7 @@ async def test_run_script():
out, err = await RunCode.run_script(".", command=["python", "-c", "print(1/0)"])
assert "ZeroDivisionError" in err
+
@pytest.mark.asyncio
async def test_run():
action = RunCode()
@@ -47,10 +47,11 @@ async def test_run():
test_file_name="",
command=["echo", "Hello World"],
working_directory=".",
- additional_python_paths=[]
+ additional_python_paths=[],
)
assert "PASS" in result
+
@pytest.mark.asyncio
async def test_run_failure():
action = RunCode()
@@ -65,6 +66,6 @@ async def test_run_failure():
test_file_name="",
command=["python", "-c", "print(1/0)"],
working_directory=".",
- additional_python_paths=[]
+ additional_python_paths=[],
)
- assert "FAIL" in result
\ No newline at end of file
+ assert "FAIL" in result
diff --git a/tests/metagpt/actions/test_write_code_review.py b/tests/metagpt/actions/test_write_code_review.py
index cee7eb941..21bc563ec 100644
--- a/tests/metagpt/actions/test_write_code_review.py
+++ b/tests/metagpt/actions/test_write_code_review.py
@@ -8,8 +8,6 @@
import pytest
from metagpt.actions.write_code_review import WriteCodeReview
-from metagpt.logs import logger
-from tests.metagpt.actions.mock import SEARCH_CODE_SAMPLE
@pytest.mark.asyncio
@@ -20,11 +18,7 @@ def add(a, b):
"""
# write_code_review = WriteCodeReview("write_code_review")
- code = await WriteCodeReview().run(
- context="编写一个从a加b的函数,返回a+b",
- code=code,
- filename="math.py"
- )
+ code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py")
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
assert isinstance(code, str)
@@ -33,6 +27,7 @@ def add(a, b):
captured = capfd.readouterr()
print(f"输出内容: {captured.out}")
+
# @pytest.mark.asyncio
# async def test_write_code_review_directly():
# code = SEARCH_CODE_SAMPLE
diff --git a/tests/metagpt/document_store/test_qdrant_store.py b/tests/metagpt/document_store/test_qdrant_store.py
new file mode 100644
index 000000000..a63a4329d
--- /dev/null
+++ b/tests/metagpt/document_store/test_qdrant_store.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/6/11 21:08
+@Author : hezhaozhao
+@File : test_qdrant_store.py
+"""
+import random
+
+from qdrant_client.models import (
+ Distance,
+ FieldCondition,
+ Filter,
+ PointStruct,
+ Range,
+ VectorParams,
+)
+
+from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore
+
+seed_value = 42
+random.seed(seed_value)
+
+vectors = [[random.random() for _ in range(2)] for _ in range(10)]
+
+points = [
+ PointStruct(
+ id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
+ )
+ for idx, vector in enumerate(vectors)
+]
+
+
+def test_milvus_store():
+ qdrant_connection = QdrantConnection(memory=True)
+ vectors_config = VectorParams(size=2, distance=Distance.COSINE)
+ qdrant_store = QdrantStore(qdrant_connection)
+ qdrant_store.create_collection("Book", vectors_config, force_recreate=True)
+ assert qdrant_store.has_collection("Book") is True
+ qdrant_store.delete_collection("Book")
+ assert qdrant_store.has_collection("Book") is False
+ qdrant_store.create_collection("Book", vectors_config)
+ assert qdrant_store.has_collection("Book") is True
+ qdrant_store.add("Book", points)
+ results = qdrant_store.search("Book", query=[1.0, 1.0])
+ assert results[0]["id"] == 2
+ assert results[0]["score"] == 0.999106722578389
+ assert results[1]["score"] == 7
+ assert results[1]["score"] == 0.9961650411397226
+ results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
+ assert results[0]["id"] == 2
+ assert results[0]["score"] == 0.999106722578389
+ assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
+ assert results[1]["score"] == 7
+ assert results[1]["score"] == 0.9961650411397226
+ assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
+ results = qdrant_store.search(
+ "Book",
+ query=[1.0, 1.0],
+ query_filter=Filter(
+ must=[FieldCondition(key="rand_number", range=Range(gte=8))]
+ ),
+ )
+ assert results[0]["id"] == 8
+ assert results[0]["score"] == 0.9100373450784073
+ assert results[1]["id"] == 9
+ assert results[1]["score"] == 0.7127610621127889
+ results = qdrant_store.search(
+ "Book",
+ query=[1.0, 1.0],
+ query_filter=Filter(
+ must=[FieldCondition(key="rand_number", range=Range(gte=8))]
+ ),
+ return_vector=True,
+ )
+ assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
+ assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]
diff --git a/tests/metagpt/roles/test_researcher.py b/tests/metagpt/roles/test_researcher.py
new file mode 100644
index 000000000..01b5dae3b
--- /dev/null
+++ b/tests/metagpt/roles/test_researcher.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+from random import random
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from metagpt.roles import researcher
+
+
+async def mock_llm_ask(self, prompt: str, system_msgs):
+ if "Please provide up to 2 necessary keywords" in prompt:
+ return '["dataiku", "datarobot"]'
+ elif "Provide up to 4 queries related to your research topic" in prompt:
+ return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
+ '"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
+ elif "sort the remaining search results" in prompt:
+ return '[1,2]'
+ elif "Not relevant." in prompt:
+ return "Not relevant" if random() > 0.5 else prompt[-100:]
+ elif "provide a detailed research report" in prompt:
+ return f"# Research Report\n## Introduction\n{prompt}"
+ return ""
+
+
+@pytest.mark.asyncio
+async def test_researcher(mocker):
+ with TemporaryDirectory() as dirname:
+ topic = "dataiku vs. datarobot"
+ mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
+ researcher.RESEARCH_PATH = Path(dirname)
+ await researcher.Researcher().run(topic)
+ assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py
index 101be9c69..a45a89cde 100644
--- a/tests/metagpt/roles/ui_role.py
+++ b/tests/metagpt/roles/ui_role.py
@@ -2,22 +2,19 @@
# @Date : 2023/7/15 16:40
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
-import re
import os
-from importlib import import_module
+import re
from functools import wraps
+from importlib import import_module
-from metagpt.logs import logger
-from metagpt.actions import Action, ActionOutput
-from metagpt.roles import ProductManager, Role
-from metagpt.schema import Message
+from metagpt.actions import Action, ActionOutput, WritePRD
from metagpt.const import WORKSPACE_ROOT
-
-from metagpt.actions import WritePRD
-from metagpt.software_company import SoftwareCompany
+from metagpt.logs import logger
+from metagpt.roles import Role
+from metagpt.schema import Message
from metagpt.tools.sd_engine import SDEngine
-PROMPT_TEMPLATE = '''
+PROMPT_TEMPLATE = """
# Context
{context}
@@ -34,9 +31,9 @@ Attention: Use '##' to split sections, not '#', and '## This is a paragraph with a link and some emphasized text.
+| Header 1 | +Header 2 | +
|---|---|
| Row 1, Cell 1 | +Row 1, Cell 2 | +
| Row 2, Cell 1 | +Row 2, Cell 2 | +
+
+
+
+
+"""
+
+CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
+'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
+'with a class "box".a link'
+
+
+def test_web_page():
+ page = parse_html.WebPage(inner_text=CONTENT, html=PAGE, url="http://example.com")
+ assert page.title == "Random HTML Example"
+ assert list(page.get_links()) == ["http://example.com/test", "https://metagpt.com"]
+
+
+def test_get_page_content():
+ ret = parse_html.get_html_content(PAGE, "http://example.com")
+ assert ret == CONTENT
diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py
index de8ccba4c..69f317f79 100644
--- a/tests/metagpt/utils/test_serialize.py
+++ b/tests/metagpt/utils/test_serialize.py
@@ -3,94 +3,64 @@
# @Desc : the unittest of serialize
from typing import List, Tuple
-import pytest
-from pydantic import create_model
-
-from metagpt.actions.action_output import ActionOutput
from metagpt.actions import WritePRD
+from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
-from metagpt.utils.serialize import actionoutout_schema_to_mapping, serialize_message, deserialize_message
+from metagpt.utils.serialize import (
+ actionoutout_schema_to_mapping,
+ deserialize_message,
+ serialize_message,
+)
def test_actionoutout_schema_to_mapping():
- schema = {
- 'title': 'test',
- 'type': 'object',
- 'properties': {
- 'field': {
- 'title': 'field',
- 'type': 'string'
- }
- }
- }
+ schema = {"title": "test", "type": "object", "properties": {"field": {"title": "field", "type": "string"}}}
mapping = actionoutout_schema_to_mapping(schema)
- assert mapping['field'] == (str, ...)
+ assert mapping["field"] == (str, ...)
schema = {
- 'title': 'test',
- 'type': 'object',
- 'properties': {
- 'field': {
- 'title': 'field',
- 'type': 'array',
- 'items': {
- 'type': 'string'
- }
- }
- }
+ "title": "test",
+ "type": "object",
+ "properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}},
}
mapping = actionoutout_schema_to_mapping(schema)
- assert mapping['field'] == (List[str], ...)
+ assert mapping["field"] == (List[str], ...)
schema = {
- 'title': 'test',
- 'type': 'object',
- 'properties': {
- 'field': {
- 'title': 'field',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'minItems': 2,
- 'maxItems': 2,
- 'items': [
- {
- 'type': 'string'
- },
- {
- 'type': 'string'
- }
- ]
- }
+ "title": "test",
+ "type": "object",
+ "properties": {
+ "field": {
+ "title": "field",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "minItems": 2,
+ "maxItems": 2,
+ "items": [{"type": "string"}, {"type": "string"}],
+ },
}
- }
+ },
}
mapping = actionoutout_schema_to_mapping(schema)
- assert mapping['field'] == (List[Tuple[str, str]], ...)
+ assert mapping["field"] == (List[Tuple[str, str]], ...)
assert True, True
def test_serialize_and_deserialize_message():
- out_mapping = {
- 'field1': (str, ...),
- 'field2': (List[str], ...)
- }
- out_data = {
- 'field1': 'field1 value',
- 'field2': ['field2 value1', 'field2 value2']
- }
- ic_obj = ActionOutput.create_model_class('prd', out_mapping)
+ out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
+ out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
+ ic_obj = ActionOutput.create_model_class("prd", out_mapping)
- message = Message(content='prd demand',
- instruct_content=ic_obj(**out_data),
- role='user',
- cause_by=WritePRD) # WritePRD as test action
+ message = Message(
+ content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
+ ) # WritePRD as test action
message_ser = serialize_message(message)
new_message = deserialize_message(message_ser)
assert new_message.content == message.content
assert new_message.cause_by == message.cause_by
- assert new_message.instruct_content.field1 == out_data['field1']
+ assert new_message.instruct_content.field1 == out_data["field1"]
diff --git a/tests/metagpt/utils/test_text.py b/tests/metagpt/utils/test_text.py
new file mode 100644
index 000000000..0caf8abaa
--- /dev/null
+++ b/tests/metagpt/utils/test_text.py
@@ -0,0 +1,77 @@
+import pytest
+
+from metagpt.utils.text import (
+ decode_unicode_escape,
+ generate_prompt_chunk,
+ reduce_message_length,
+ split_paragraph,
+)
+
+
+def _msgs():
+ length = 20
+ while length:
+ yield "Hello," * 1000 * length
+ length -= 1
+
+
+def _paragraphs(n):
+ return " ".join("Hello World." for _ in range(n))
+
+
+@pytest.mark.parametrize(
+ "msgs, model_name, system_text, reserved, expected",
+ [
+ (_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
+ (_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
+ (_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
+ (_msgs(), "gpt-4", "System", 2000, 3),
+ (_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
+ (_msgs(), "gpt-4-32k", "System", 4000, 14),
+ (_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
+ ]
+)
+def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
+ assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
+
+
+@pytest.mark.parametrize(
+ "text, prompt_template, model_name, system_text, reserved, expected",
+ [
+ (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
+ (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
+ (" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
+ (" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
+ ]
+)
+def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
+ ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
+ assert len(ret) == expected
+
+
+@pytest.mark.parametrize(
+ "paragraph, sep, count, expected",
+ [
+ (_paragraphs(10), ".", 2, [_paragraphs(5), f" {_paragraphs(5)}"]),
+ (_paragraphs(10), ".", 3, [_paragraphs(4), f" {_paragraphs(3)}", f" {_paragraphs(3)}"]),
+ (f"{_paragraphs(5)}\n{_paragraphs(3)}", "\n.", 2, [f"{_paragraphs(5)}\n", _paragraphs(3)]),
+ ("......", ".", 2, ["...", "..."]),
+ ("......", ".", 3, ["..", "..", ".."]),
+ (".......", ".", 2, ["....", "..."]),
+ ]
+)
+def test_split_paragraph(paragraph, sep, count, expected):
+ ret = split_paragraph(paragraph, sep, count)
+ assert ret == expected
+
+
+@pytest.mark.parametrize(
+ "text, expected",
+ [
+ ("Hello\\nWorld", "Hello\nWorld"),
+ ("Hello\\tWorld", "Hello\tWorld"),
+ ("Hello\\u0020World", "Hello World"),
+ ]
+)
+def test_decode_unicode_escape(text, expected):
+ assert decode_unicode_escape(text) == expected