Merge branch 'geekan:main' into main

This commit is contained in:
brucemeek 2023-08-17 11:02:27 -05:00 committed by GitHub
commit a0e6d20034
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
50 changed files with 1734 additions and 250 deletions

277
metagpt/actions/research.py Normal file
View file

@ -0,0 +1,277 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import json
from typing import Callable
from pydantic import parse_obj_as
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
LANG_PROMPT = "Please respond in {language}."
RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \
written, critically acclaimed, objective and structured reports on the given text."""
RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is:\n#TOPIC#\n{topic}"
SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic for Google search. \
Your response must be in JSON format, for example: ["keyword1", "keyword2"]."""
SUMMARIZE_SEARCH_PROMPT = """### Requirements
1. The keywords related to your research topic and the search results are shown in the "Search Result Information" section.
2. Provide up to {decomposition_nums} queries related to your research topic base on the search results.
3. Please respond in the following JSON format: ["query1", "query2", "query3", ...].
### Search Result Information
{search_results}
"""
COLLECT_AND_RANKURLS_PROMPT = """### Topic
{topic}
### Query
{query}
### The online search results
{results}
### Requirements
Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \
based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the
ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words.
"""
WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements
1. Utilize the text in the "Reference Information" section to respond to the question "{query}".
2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \
a comprehensive summary of the text.
3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant."
4. Include all relevant factual information, numbers, statistics, etc., if available.
### Reference Information
{content}
'''
CONDUCT_RESEARCH_PROMPT = '''### Reference Information
{content}
### Requirements
Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \
above. The report must meet the following requirements:
- Focus on directly addressing the chosen topic.
- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available.
- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable.
- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines.
- Include all source URLs in APA format at the end of the report.
'''
class CollectLinks(Action):
"""Action class to collect links from a search engine."""
def __init__(
self,
name: str = "",
*args,
rank_func: Callable[[list[str]], None] | None = None,
**kwargs,
):
super().__init__(name, *args, **kwargs)
self.desc = "Collect links from a search engine."
self.search_engine = SearchEngine()
self.rank_func = rank_func
async def run(
self,
topic: str,
decomposition_nums: int = 4,
url_per_query: int = 4,
system_text: str | None = None,
) -> dict[str, list[str]]:
"""Run the action to collect links.
Args:
topic: The research topic.
decomposition_nums: The number of search questions to generate.
url_per_query: The number of URLs to collect per search question.
system_text: The system text.
Returns:
A dictionary containing the search questions as keys and the collected URLs as values.
"""
system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
try:
keywords = json.loads(keywords)
keywords = parse_obj_as(list[str], keywords)
except Exception as e:
logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
keywords = [topic]
results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))
def gen_msg():
while True:
search_results = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results))
prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search_results=search_results)
yield prompt
remove = max(results, key=len)
remove.pop()
if len(remove) == 0:
break
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
queries = json.loads(queries)
queries = parse_obj_as(list[str], queries)
except Exception as e:
logger.exception(f"fail to break down the research question due to {e}")
queries = keywords
ret = {}
for query in queries:
ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
return ret
async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
"""Search and rank URLs based on a query.
Args:
topic: The research topic.
query: The search query.
num_results: The number of URLs to collect.
Returns:
A list of ranked URLs.
"""
max_results = max(num_results * 2, 6)
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
logger.debug(prompt)
indices = await self._aask(prompt)
try:
indices = json.loads(indices)
assert all(isinstance(i, int) for i in indices)
except Exception as e:
logger.exception(f"fail to rank results for {e}")
indices = list(range(max_results))
results = [results[i] for i in indices]
if self.rank_func:
results = self.rank_func(results)
return [i["link"] for i in results[:num_results]]
class WebBrowseAndSummarize(Action):
"""Action class to explore the web and provide summaries of articles and webpages."""
def __init__(
self,
*args,
browse_func: Callable[[list[str]], None] | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if browse_func else None,
run_func=browse_func,
)
self.desc = "Explore the web and provide summaries of articles and webpages."
async def run(
self,
url: str,
*urls: str,
query: str,
system_text: str = RESEARCH_BASE_SYSTEM,
) -> dict[str, str]:
"""Run the action to browse the web and provide summaries.
Args:
url: The main URL to browse.
urls: Additional URLs to browse.
query: The research question.
system_text: The system text.
Returns:
A dictionary containing the URLs as keys and their summaries as values.
"""
contents = await self.web_browser_engine.run(url, *urls)
if not urls:
contents = [contents]
summaries = {}
prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
for u, content in zip([url, *urls], contents):
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
continue
chunk_summaries.append(summary)
if not chunk_summaries:
summaries[u] = None
continue
if len(chunk_summaries) == 1:
summaries[u] = chunk_summaries[0]
continue
content = "\n".join(chunk_summaries)
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
summary = await self._aask(prompt, [system_text])
summaries[u] = summary
return summaries
class ConductResearch(Action):
"""Action class to conduct research and generate a research report."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if CONFIG.model_for_researcher_report:
self.llm.model = CONFIG.model_for_researcher_report
async def run(
self,
topic: str,
content: str,
system_text: str = RESEARCH_BASE_SYSTEM,
) -> str:
"""Run the action to conduct research and generate a research report.
Args:
topic: The research topic.
content: The content for research.
system_text: The system text.
Returns:
The generated research report.
"""
prompt = CONDUCT_RESEARCH_PROMPT.format(topic=topic, content=content)
logger.debug(prompt)
self.llm.auto_max_tokens = True
return await self._aask(prompt, [system_text])
def get_research_system_text(topic: str, language: str):
"""Get the system text for conducting research.
Args:
topic: The research topic.
language: The language for the system text.
Returns:
The system text for conducting research.
"""
return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language)))

View file

@ -5,13 +5,13 @@
@Author : alexanderwu
@File : run_code.py
"""
import traceback
import os
import subprocess
from typing import List, Tuple
import traceback
from typing import Tuple
from metagpt.logs import logger
from metagpt.actions.action import Action
from metagpt.logs import logger
PROMPT_TEMPLATE = """
Role: You are a senior development and qa engineer, your role is summarize the code running result.
@ -27,7 +27,7 @@ Please summarize the cause of the errors and give correction instruction
Determine the ONE file to rewrite in order to fix the error, for example, xyz.py, or test_xyz.py
## Status:
Determine if all of the code works fine, if so write PASS, else FAIL,
WRITE ONLY ONE WORD, PASS OR FAIL, IN THI SECTION
WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION
## Send To:
Please write Engineer if the errors are due to problematic development codes, and QaEngineer to problematic test codes, and NoOne if there are no errors,
WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.
@ -55,6 +55,7 @@ standard output: {outs};
standard errors: {errs};
"""
class RunCode(Action):
def __init__(self, name="RunCode", context=None, llm=None):
super().__init__(name, context, llm)
@ -65,7 +66,7 @@ class RunCode(Action):
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
return namespace.get('result', ""), ""
return namespace.get("result", ""), ""
except Exception:
# If there is an error in the code, return the error message
return "", traceback.format_exc()
@ -81,10 +82,12 @@ class RunCode(Action):
# Modify the PYTHONPATH environment variable
additional_python_paths = [working_directory] + additional_python_paths
additional_python_paths = ":".join(additional_python_paths)
env['PYTHONPATH'] = additional_python_paths + ':' + env.get('PYTHONPATH', '')
env["PYTHONPATH"] = additional_python_paths + ":" + env.get("PYTHONPATH", "")
# Start the subprocess
process = subprocess.Popen(command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
process = subprocess.Popen(
command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
try:
# Wait for the process to complete, with a timeout
@ -93,7 +96,7 @@ class RunCode(Action):
logger.info("The command did not complete within the given timeout.")
process.kill() # Kill the process if it times out
stdout, stderr = process.communicate()
return stdout.decode('utf-8'), stderr.decode('utf-8')
return stdout.decode("utf-8"), stderr.decode("utf-8")
async def run(
self, code, mode="script", code_file_name="", test_code="", test_file_name="", command=[], **kwargs
@ -108,11 +111,13 @@ class RunCode(Action):
logger.info(f"{errs=}")
context = CONTEXT.format(
code=code, code_file_name=code_file_name,
test_code=test_code, test_file_name=test_file_name,
code=code,
code_file_name=code_file_name,
test_code=test_code,
test_file_name=test_file_name,
command=" ".join(command),
outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
errs=errs[:10000] # truncate errors to avoid token overflow
outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
errs=errs[:10000], # truncate errors to avoid token overflow
)
prompt = PROMPT_TEMPLATE.format(context=context)

View file

@ -5,7 +5,6 @@
@Author : alexanderwu
@File : environment.py
"""
from metagpt.logs import logger
from metagpt.actions.action import Action
from metagpt.utils.common import CodeParser
@ -29,6 +28,7 @@ you should correctly import the necessary classes based on these file locations!
## {test_file_name}: Write test code with triple quoto. Do your best to implement THIS ONLY ONE FILE.
"""
class WriteTest(Action):
def __init__(self, name="WriteTest", context=None, llm=None):
super().__init__(name, context, llm)
@ -43,7 +43,7 @@ class WriteTest(Action):
code_to_test=code_to_test,
test_file_name=test_file_name,
source_file_path=source_file_path,
workspace=workspace
workspace=workspace,
)
code = await self.write_code(prompt)
return code

View file

@ -32,5 +32,6 @@ UT_PY_PATH = UT_PATH / "files/ut/"
API_QUESTIONS_PATH = UT_PATH / "files/question/"
YAPI_URL = "http://yapi.deepwisdomai.com/"
TMP = PROJECT_ROOT / 'tmp'
RESEARCH_PATH = DATA_PATH / "research"
MEM_TTL = 24 * 30 * 3600

View file

@ -7,3 +7,5 @@
"""
from metagpt.document_store.faiss_store import FaissStore
__all__ = ["FaissStore"]

View file

@ -15,7 +15,7 @@ class BaseStore(ABC):
"""FIXME: consider add_index, set_index and think about granularity."""
@abstractmethod
def search(self, query, *args, **kwargs):
def search(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod

View file

@ -0,0 +1,129 @@
from dataclasses import dataclass
from typing import List
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, PointStruct, VectorParams
from metagpt.document_store.base_store import BaseStore
@dataclass
class QdrantConnection:
"""
Args:
url: qdrant url
host: qdrant host
port: qdrant port
memory: qdrant service use memory mode
api_key: qdrant cloud api_key
"""
url: str = None
host: str = None
port: int = None
memory: bool = False
api_key: str = None
class QdrantStore(BaseStore):
def __init__(self, connect: QdrantConnection):
if connect.memory:
self.client = QdrantClient(":memory:")
elif connect.url:
self.client = QdrantClient(url=connect.url, api_key=connect.api_key)
elif connect.host and connect.port:
self.client = QdrantClient(
host=connect.host, port=connect.port, api_key=connect.api_key
)
else:
raise Exception("please check QdrantConnection.")
def create_collection(
self,
collection_name: str,
vectors_config: VectorParams,
force_recreate=False,
**kwargs,
):
"""
create a collection
Args:
collection_name: collection name
vectors_config: VectorParams object,detail in https://github.com/qdrant/qdrant-client
force_recreate: default is False, if True, will delete exists collection,then create it
**kwargs:
Returns:
"""
try:
self.client.get_collection(collection_name)
if force_recreate:
res = self.client.recreate_collection(
collection_name, vectors_config=vectors_config, **kwargs
)
return res
return True
except: # noqa: E722
return self.client.recreate_collection(
collection_name, vectors_config=vectors_config, **kwargs
)
def has_collection(self, collection_name: str):
try:
self.client.get_collection(collection_name)
return True
except: # noqa: E722
return False
def delete_collection(self, collection_name: str, timeout=60):
res = self.client.delete_collection(collection_name, timeout=timeout)
if not res:
raise Exception(f"Delete collection {collection_name} failed.")
def add(self, collection_name: str, points: List[PointStruct]):
"""
add some vector data to qdrant
Args:
collection_name: collection name
points: list of PointStruct object, about PointStruct detail in https://github.com/qdrant/qdrant-client
Returns: NoneX
"""
# self.client.upload_records()
self.client.upsert(
collection_name,
points,
)
def search(
self,
collection_name: str,
query: List[float],
query_filter: Filter = None,
k=10,
return_vector=False,
):
"""
vector search
Args:
collection_name: qdrant collection name
query: input vector
query_filter: Filter object, detail in https://github.com/qdrant/qdrant-client
k: return the most similar k pieces of data
return_vector: whether return vector
Returns: list of dict
"""
hits = self.client.search(
collection_name=collection_name,
query_vector=query,
query_filter=query_filter,
limit=k,
with_vectors=return_vector,
)
return [hit.__dict__ for hit in hits]
def write(self, *args, **kwargs):
pass

View file

@ -9,3 +9,8 @@
from metagpt.memory.memory import Memory
from metagpt.memory.longterm_memory import LongTermMemory
__all__ = [
"Memory",
"LongTermMemory",
]

View file

@ -2,12 +2,10 @@
# -*- coding: utf-8 -*-
# @Desc : the implement of Long-term memory
from typing import Iterable, Type
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.memory import Memory
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
class LongTermMemory(Memory):
@ -27,10 +25,11 @@ class LongTermMemory(Memory):
messages = self.memory_storage.recover_memory(role_id)
self.rc = rc
if not self.memory_storage.is_initialized:
logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty')
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
else:
logger.warning(f'Agent {role_id} has existed memory storage with {len(messages)} messages '
f'and has recovered them.')
logger.warning(
f"Agent {role_id} has existed memory storage with {len(messages)} messages " f"and has recovered them."
)
self.msg_from_recover = True
self.add_batch(messages)
self.msg_from_recover = False

View file

@ -7,3 +7,6 @@
"""
from metagpt.provider.openai_api import OpenAIGPTAPI
__all__ = ["OpenAIGPTAPI"]

View file

@ -122,6 +122,15 @@ See FAQ 5.8
raise retry_state.outcome.exception()
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning("""
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
""")
raise retry_state.outcome.exception()
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples
@ -223,11 +232,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
usage = {}
if CONFIG.calc_usage:
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage['prompt_tokens'] = prompt_tokens
usage['completion_tokens'] = completion_tokens
return usage
try:
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage['prompt_tokens'] = prompt_tokens
usage['completion_tokens'] = completion_tokens
return usage
except Exception as e:
logger.error("usage calculation failed!", e)
else:
return usage
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
"""Return full JSON"""
@ -256,10 +270,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return results
def _update_costs(self, usage: dict):
if CONFIG.update_costs:
prompt_tokens = int(usage['prompt_tokens'])
completion_tokens = int(usage['completion_tokens'])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
if CONFIG.calc_usage:
try:
prompt_tokens = int(usage['prompt_tokens'])
completion_tokens = int(usage['completion_tokens'])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed!", e)
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -8,10 +8,23 @@
from metagpt.roles.role import Role
from metagpt.roles.architect import Architect
from metagpt.roles.product_manager import ProductManager
from metagpt.roles.project_manager import ProjectManager
from metagpt.roles.product_manager import ProductManager
from metagpt.roles.engineer import Engineer
from metagpt.roles.qa_engineer import QaEngineer
from metagpt.roles.seacher import Searcher
from metagpt.roles.sales import Sales
from metagpt.roles.customer_service import CustomerService
__all__ = [
"Role",
"Architect",
"ProjectManager",
"ProductManager",
"Engineer",
"QaEngineer",
"Searcher",
"Sales",
"CustomerService",
]

View file

@ -6,40 +6,44 @@
@File : qa_engineer.py
"""
import os
import re
from pathlib import Path
from typing import Type
from metagpt.actions import WriteTest, WriteCode, WriteDesign, RunCode, DebugError
from metagpt.actions import DebugError, RunCode, WriteCode, WriteDesign, WriteTest
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.roles.engineer import Engineer
from metagpt.utils.common import CodeParser, parse_recipient
from metagpt.utils.special_tokens import MSG_SEP, FILENAME_CODE_SEP
from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP
class QaEngineer(Role):
def __init__(self, name="Edward", profile="QaEngineer",
goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
test_round_allowed=5):
def __init__(
self,
name="Edward",
profile="QaEngineer",
goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
test_round_allowed=5,
):
super().__init__(name, profile, goal, constraints)
self._init_actions([WriteTest]) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
self._init_actions(
[WriteTest]
) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
self._watch([WriteCode, WriteTest, RunCode, DebugError])
self.test_round = 0
self.test_round_allowed = test_round_allowed
@classmethod
def parse_workspace(cls, system_design_msg: Message) -> str:
if not system_design_msg.instruct_content:
return system_design_msg.instruct_content.dict().get("Python package name")
return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
def get_workspace(self, return_proj_dir=True) -> Path:
msg = self._rc.memory.get_by_action(WriteDesign)[-1]
if not msg:
return WORKSPACE_ROOT / 'src'
return WORKSPACE_ROOT / "src"
workspace = self.parse_workspace(msg)
# project directory: workspace/{package_name}, which contains package source code folder, tests folder, resources folder, etc.
if return_proj_dir:
@ -48,49 +52,52 @@ class QaEngineer(Role):
return WORKSPACE_ROOT / workspace / workspace
def write_file(self, filename: str, code: str):
workspace = self.get_workspace() / 'tests'
workspace = self.get_workspace() / "tests"
file = workspace / filename
file.parent.mkdir(parents=True, exist_ok=True)
file.write_text(code)
async def _write_test(self, message: Message) -> None:
code_msgs = message.content.split(MSG_SEP)
result_msg_all = []
# result_msg_all = []
for code_msg in code_msgs:
# write tests
file_name, file_path = code_msg.split(FILENAME_CODE_SEP)
code_to_test = open(file_path, "r").read()
if "test" in file_name:
continue # Engineer might write some test files, skip testing a test file
continue # Engineer might write some test files, skip testing a test file
test_file_name = "test_" + file_name
test_file_path = self.get_workspace() / "tests" / test_file_name
logger.info(f'Writing {test_file_name}..')
logger.info(f"Writing {test_file_name}..")
test_code = await WriteTest().run(
code_to_test=code_to_test,
test_file_name=test_file_name,
# source_file_name=file_name,
source_file_path=file_path,
workspace=self.get_workspace()
workspace=self.get_workspace(),
)
self.write_file(test_file_name, test_code)
# prepare context for run tests in next round
command = ['python', f'tests/{test_file_name}']
command = ["python", f"tests/{test_file_name}"]
file_info = {
"file_name": file_name, "file_path": str(file_path),
"test_file_name": test_file_name, "test_file_path": str(test_file_path),
"command": command
"file_name": file_name,
"file_path": str(file_path),
"test_file_name": test_file_name,
"test_file_path": str(test_file_path),
"command": command,
}
msg = Message(
content=str(file_info), role=self.profile, cause_by=WriteTest,
sent_from=self.profile, send_to=self.profile
content=str(file_info),
role=self.profile,
cause_by=WriteTest,
sent_from=self.profile,
send_to=self.profile,
)
self._publish_message(msg)
logger.info(f'Done {self.get_workspace()}/tests generating.')
logger.info(f"Done {self.get_workspace()}/tests generating.")
async def _run_code(self, msg):
file_info = eval(msg.content)
development_file_path = file_info["file_path"]
@ -110,17 +117,14 @@ class QaEngineer(Role):
test_code=test_code,
test_file_name=file_info["test_file_name"],
command=file_info["command"],
working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here
additional_python_paths=[development_code_dir], # workspace/package_name/package_name,
# import statement inside package code needs this
working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here
additional_python_paths=[development_code_dir], # workspace/package_name/package_name,
# import statement inside package code needs this
)
recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself
recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself
content = str(file_info) + FILENAME_CODE_SEP + result_msg
msg = Message(
content=content, role=self.profile, cause_by=RunCode,
sent_from=self.profile, send_to=recipient
)
msg = Message(content=content, role=self.profile, cause_by=RunCode, sent_from=self.profile, send_to=recipient)
self._publish_message(msg)
async def _debug_error(self, msg):
@ -128,21 +132,27 @@ class QaEngineer(Role):
file_name, code = await DebugError().run(context)
if file_name:
self.write_file(file_name, code)
recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self
msg = Message(content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient)
recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self
msg = Message(
content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient
)
self._publish_message(msg)
async def _observe(self) -> int:
await super()._observe()
self._rc.news = [msg for msg in self._rc.news \
if msg.send_to == self.profile] # only relevant msgs count as observed news
self._rc.news = [
msg for msg in self._rc.news if msg.send_to == self.profile
] # only relevant msgs count as observed news
return len(self._rc.news)
async def _act(self) -> Message:
if self.test_round > self.test_round_allowed:
result_msg = Message(
content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)",
role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to=""
role=self.profile,
cause_by=WriteTest,
sent_from=self.profile,
send_to="",
)
return result_msg
@ -161,6 +171,9 @@ class QaEngineer(Role):
self.test_round += 1
result_msg = Message(
content=f"Round {self.test_round} of tests done",
role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to=""
role=self.profile,
cause_by=WriteTest,
sent_from=self.profile,
send_to="",
)
return result_msg

View file

@ -0,0 +1,93 @@
#!/usr/bin/env python
import asyncio
from pydantic import BaseModel
from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize
from metagpt.actions.research import get_research_system_text
from metagpt.const import RESEARCH_PATH
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
class Report(BaseModel):
topic: str
links: dict[str, list[str]] = None
summaries: list[tuple[str, str]] = None
content: str = ""
class Researcher(Role):
def __init__(
self,
name: str = "David",
profile: str = "Researcher",
goal: str = "Gather information and conduct research",
constraints: str = "Ensure accuracy and relevance of information",
language: str = "en-us",
**kwargs,
):
super().__init__(name, profile, goal, constraints, **kwargs)
self._init_actions([CollectLinks(name), WebBrowseAndSummarize(name), ConductResearch(name)])
self.language = language
if language not in ("en-us", "zh-cn"):
logger.warning(f"The language `{language}` has not been tested, it may not work.")
async def _think(self) -> None:
if self._rc.todo is None:
self._set_state(0)
return
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
else:
self._rc.todo = None
async def _act(self) -> Message:
logger.info(f"{self._setting}: ready to {self._rc.todo}")
todo = self._rc.todo
msg = self._rc.memory.get(k=1)[0]
if isinstance(msg.instruct_content, Report):
instruct_content = msg.instruct_content
topic = instruct_content.topic
else:
topic = msg.content
research_system_text = get_research_system_text(topic, self.language)
if isinstance(todo, CollectLinks):
links = await todo.run(topic, 4, 4)
ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo))
elif isinstance(todo, WebBrowseAndSummarize):
links = instruct_content.links
todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
summaries = await asyncio.gather(*todos)
summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary)
ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo))
else:
summaries = instruct_content.summaries
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(self._rc.todo))
self._rc.memory.add(ret)
return ret
async def _react(self) -> Message:
while True:
await self._think()
if self._rc.todo is None:
break
msg = await self._act()
report = msg.instruct_content
self.write_report(report.topic, report.content)
return msg
def write_report(self, topic: str, content: str):
filepath = RESEARCH_PATH / f"{topic}.md"
filepath.write_text(content)
if __name__ == "__main__":
role = Researcher(language="en-us")
asyncio.run(role.run("dataiku vs. datarobot"))

View file

@ -14,6 +14,7 @@ class SearchEngineType(Enum):
SERPAPI_GOOGLE = auto()
DIRECT_GOOGLE = auto()
SERPER_GOOGLE = auto()
DUCK_DUCK_GO = auto()
CUSTOM_ENGINE = auto()

View file

@ -0,0 +1,107 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload
from duckduckgo_search import DDGS
from googleapiclient.errors import HttpError
from metagpt.config import CONFIG
from metagpt.logs import logger
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
To use this module, you should have the `duckduckgo_search` Python package installed.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
) -> list[dict[str, str]]:
...
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API
Args:
query: The search query.
max_results: The number of results to return.
as_string: A boolean flag to determine the return type of the results. If True, the function will
return a formatted string with the search results. If False, it will return a list of dictionaries
containing detailed information about each search result.
Returns:
The results of the search.
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self._search_from_ddgs,
query,
max_results,
)
try:
search_results = await future
# Extract the search result items from the response
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
# Return the list of search result URLs
if as_string:
return json.dumps(search_results, ensure_ascii=False)
return search_results
def _search_from_ddgs(self, query: str, max_results: int):
return [
{
"link": i["href"],
"snippet": i["body"],
"title": i["title"]
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
]
if __name__ == "__main__":
import fire
fire.Fire(DDGAPIWrapper().run)

View file

@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from urllib.parse import urlparse
import httplib2
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from metagpt.config import CONFIG
from metagpt.logs import logger
class GoogleAPIWrapper:
"""Wrapper around GoogleAPI.
To use this module, you should have the `google-api-python-client` Python package installed
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
https://programmablesearchengine.google.com/controlpanel/all.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
build_kwargs = {"developerKey": CONFIG.google_api_key}
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"
build_kwargs["http"] = httplib2.Http(
proxy_info=httplib2.ProxyInfo(
getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"),
parse_result.hostname,
parse_result.port,
),
)
service = build("customsearch", "v1", **build_kwargs)
self.google_api_client = service.cse()
self.custom_search_engine_id = CONFIG.google_cse_id
self.loop = loop
self.executor = executor
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
focus: list[str] | None = None,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API.
Args:
query: The search query.
max_results: The number of results to return.
as_string: A boolean flag to determine the return type of the results. If True, the function will
return a formatted string with the search results. If False, it will return a list of dictionaries
containing detailed information about each search result.
focus: Specific information to be focused on from each search result.
Returns:
The results of the search.
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self.google_api_client.list(
q=query,
num=max_results,
cx=self.custom_search_engine_id
).execute
)
try:
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
focus = focus or ["snippet", "link", "title"]
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
# Return the list of search result URLs
if as_string:
return safe_google_results(details)
return details
def safe_google_results(results: str | list) -> str:
"""Return the results of a google search in a safe format.
Args:
results: The search results.
Returns:
The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps([result for result in results])
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
if __name__ == "__main__":
import fire
fire.Fire(GoogleAPIWrapper().run)

View file

@ -13,3 +13,12 @@ from metagpt.utils.token_counter import (
count_message_tokens,
count_string_tokens,
)
__all__ = [
"read_docx",
"Singleton",
"TOKEN_COSTS",
"count_message_tokens",
"count_string_tokens",
]

View file

@ -0,0 +1,57 @@
#!/usr/bin/env python
from __future__ import annotations
from typing import Generator, Optional
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel
class WebPage(BaseModel):
inner_text: str
html: str
url: str
class Config:
underscore_attrs_are_private = True
_soup : Optional[BeautifulSoup] = None
_title: Optional[str] = None
@property
def soup(self) -> BeautifulSoup:
if self._soup is None:
self._soup = BeautifulSoup(self.html, "html.parser")
return self._soup
@property
def title(self):
if self._title is None:
title_tag = self.soup.find("title")
self._title = title_tag.text.strip() if title_tag is not None else ""
return self._title
def get_links(self) -> Generator[str, None, None]:
for i in self.soup.find_all("a", href=True):
url = i["href"]
result = urlparse(url)
if not result.scheme and result.path:
yield urljoin(self.url, url)
elif url.startswith(("http://", "https://")):
yield urljoin(self.url, url)
def get_html_content(page: str, base: str):
soup = _get_soup(page)
return soup.get_text(strip=True)
def _get_soup(page: str):
soup = BeautifulSoup(page, "html.parser")
# https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
for s in soup(["style", "script", "[document]", "head", "title"]):
s.extract()
return soup

View file

@ -3,14 +3,11 @@
# @Desc : the implement of serialization and deserialization
import copy
from typing import Tuple, List, Type, Union, Dict
import pickle
from collections import defaultdict
from pydantic import create_model
from typing import Dict, List, Tuple
from metagpt.schema import Message
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
@ -34,12 +31,12 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
```
"""
mapping = dict()
for field, property in schema['properties'].items():
if property['type'] == 'string':
for field, property in schema["properties"].items():
if property["type"] == "string":
mapping[field] = (str, ...)
elif property['type'] == 'array' and property['items']['type'] == 'string':
elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
elif property['type'] == 'array' and property['items']['type'] == 'array':
elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `Tuple[str, str]` situation
mapping[field] = (List[Tuple[str, str]], ...)
return mapping
@ -53,11 +50,7 @@ def serialize_message(message: Message):
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
message_cp.instruct_content = {
'class': schema['title'],
'mapping': mapping,
'value': ic.dict()
}
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
msg_ser = pickle.dumps(message_cp)
return msg_ser
@ -67,9 +60,8 @@ def deserialize_message(message_ser: str) -> Message:
message = pickle.loads(message_ser)
if message.instruct_content:
ic = message.instruct_content
ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
mapping=ic['mapping'])
ic_new = ic_obj(**ic['value'])
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new
return message

124
metagpt/utils/text.py Normal file
View file

@ -0,0 +1,124 @@
from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size.
Args:
msgs: A generator of strings representing progressively shorter valid prompts.
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
system_text: The system prompts.
reserved: The number of reserved tokens.
Returns:
The concatenated message segments reduced to fit within the maximum token size.
Raises:
RuntimeError: If it fails to reduce the concatenated message length.
"""
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
for msg in msgs:
if count_string_tokens(msg, model_name) < max_token:
return msg
raise RuntimeError("fail to reduce message length")
def generate_prompt_chunk(
text: str,
prompt_template: str,
model_name: str,
system_text: str,
reserved: int = 0,
) -> Generator[str, None, None]:
"""Split the text into chunks of a maximum token size.
Args:
text: The text to split.
prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}".
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
system_text: The system prompts.
reserved: The number of reserved tokens.
Yields:
The chunk of text.
"""
paragraphs = text.splitlines(keepends=True)
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
token = count_string_tokens(paragraph, model_name)
if current_token + token <= max_token:
current_lines.append(paragraph)
current_token += token
elif token > max_token:
paragraphs = split_paragraph(paragraph) + paragraphs
continue
else:
yield prompt_template.format("".join(current_lines))
current_lines = [paragraph]
current_token = token
if current_lines:
yield prompt_template.format("".join(current_lines))
def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
"""Split a paragraph into multiple parts.
Args:
paragraph: The paragraph to split.
sep: The separator character.
count: The number of parts to split the paragraph into.
Returns:
A list of split parts of the paragraph.
"""
for i in sep:
sentences = list(_split_text_with_ends(paragraph, i))
if len(sentences) <= 1:
continue
ret = ["".join(j) for j in _split_by_count(sentences, count)]
return ret
return _split_by_count(paragraph, count)
def decode_unicode_escape(text: str) -> str:
"""Decode a text with unicode escape sequences.
Args:
text: The text to decode.
Returns:
The decoded text.
"""
return text.encode("utf-8").decode("unicode_escape", "ignore")
def _split_by_count(lst: Sequence , count: int):
avg = len(lst) // count
remainder = len(lst) % count
start = 0
for i in range(count):
end = start + avg + (1 if i < remainder else 0)
yield lst[start:end]
start = end
def _split_text_with_ends(text: str, sep: str = "."):
parts = []
for i in text:
parts.append(i)
if i == sep:
yield "".join(parts)
parts = []
if parts:
yield "".join(parts)

View file

@ -25,6 +25,21 @@ TOKEN_COSTS = {
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo-16k-0613": 16384,
"gpt-4-0314": 8192,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"text-embedding-ada-002": 8192,
}
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int:
"""
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(string))
def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
"""Calculate the maximum number of completion tokens for a given model and list of messages.
Args:
messages: A list of messages.
model: The model name.
Returns:
The maximum number of completion tokens.
"""
if model not in TOKEN_MAX:
return default
return TOKEN_MAX[model] - count_message_tokens(messages)