Merge branch 'dev' into dev

This commit is contained in:
better629 2023-12-29 02:48:39 +08:00 committed by GitHub
commit 539e1c7dce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
81 changed files with 1402 additions and 649 deletions

1
.gitignore vendored
View file

@ -167,3 +167,4 @@ tmp.png
.dependencies.json
tests/metagpt/utils/file_repo_git
*.tmp
*.png

View file

@ -54,8 +54,8 @@ # Step 2: Clone the repository to your local machine for latest version, and ins
# Step 3: setup your OPENAI_API_KEY, or make sure it existed in the env
mkdir ~/.metagpt
cp config/config.yaml ~/.metagpt/key.yaml
vim ~/.metagpt/key.yaml
cp config/config.yaml ~/.metagpt/config.yaml
vim ~/.metagpt/config.yaml
# Step 4: run metagpt cli
metagpt "Create a 2048 game in python"

View file

@ -13,7 +13,7 @@ from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.debug_error import DebugError
from metagpt.actions.design_api import WriteDesign
from metagpt.actions.design_api_review import DesignReview
from metagpt.actions.project_management import AssignTasks, WriteTasks
from metagpt.actions.project_management import WriteTasks
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
from metagpt.actions.run_code import RunCode
from metagpt.actions.search_and_summarize import SearchAndSummarize
@ -38,7 +38,6 @@ class ActionType(Enum):
RUN_CODE = RunCode
DEBUG_ERROR = DebugError
WRITE_TASKS = WriteTasks
ASSIGN_TASKS = AssignTasks
SEARCH_AND_SUMMARIZE = SearchAndSummarize
COLLECT_LINKS = CollectLinks
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize

View file

@ -352,17 +352,3 @@ class ActionNode:
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
return self
def action_node_example():
node = ActionNode(key="key-0", expected_type=str, instruction="instruction-a", example="example-b")
logger.info(node.compile(context="123", schema="raw", mode="auto"))
logger.info(node.compile(context="123", schema="json", mode="auto"))
logger.info(node.compile(context="123", schema="markdown", mode="auto"))
logger.info(node.to_dict())
logger.info(node)
if __name__ == "__main__":
action_node_example()

View file

@ -10,6 +10,3 @@ from metagpt.actions import Action
class UserRequirement(Action):
"""User Requirement without any implementation details"""
async def run(self, *args, **kwargs):
raise NotImplementedError

View file

@ -8,7 +8,6 @@
from typing import List
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.utils.mermaid import MMC1, MMC2
IMPLEMENTATION_APPROACH = ActionNode(
@ -63,12 +62,3 @@ NODES = [
]
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
def main():
prompt = DESIGN_API_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":
main()

View file

@ -39,6 +39,8 @@ class PrepareDocuments(Action):
path = Path(CONFIG.project_path)
if path.exists() and not CONFIG.inc:
shutil.rmtree(path)
CONFIG.project_path = path
CONFIG.project_name = path.name
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):

View file

@ -123,9 +123,3 @@ class WriteTasks(Action):
@staticmethod
async def _save_pdf(task_doc):
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
class AssignTasks(Action):
async def run(self, *args, **kwargs):
# Here you should implement the actual action
pass

View file

@ -86,7 +86,7 @@ class CollectLinks(Action):
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
rank_func: Union[Callable[[list[str]], None], None] = None
rank_func: Optional[Callable[[list[str]], None]] = None
async def run(
self,
@ -130,7 +130,8 @@ class CollectLinks(Action):
if len(remove) == 0:
break
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
@ -181,18 +182,18 @@ class WebBrowseAndSummarize(Action):
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = Field(
default_factory=lambda: WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None,
run_func=WebBrowseAndSummarize.browse_func,
)
)
web_browser_engine: Optional[WebBrowserEngine] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
run_func=self.browse_func,
)
async def run(
self,
url: str,

View file

@ -82,11 +82,13 @@ class RunCode(Action):
llm: BaseLLM = Field(default_factory=LLM)
@classmethod
@handle_exception
async def run_text(cls, code) -> Tuple[str, str]:
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
try:
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
except Exception as e:
return "", str(e)
return namespace.get("result", ""), ""
@classmethod

View file

@ -21,7 +21,10 @@ Example:
This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using
the specified docstring style and adds them to the code.
"""
from __future__ import annotations
import ast
from pathlib import Path
from typing import Literal, Optional
from pydantic import Field
@ -29,7 +32,7 @@ from pydantic import Field
from metagpt.actions.action import Action
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import OutputParser
from metagpt.utils.common import OutputParser, aread, awrite
from metagpt.utils.pycst import merge_docstring
PYTHON_DOCSTRING_SYSTEM = """### Requirements
@ -187,6 +190,16 @@ class WriteDocstring(Action):
documented_code = OutputParser.parse_python_code(documented_code)
return merge_docstring(code, documented_code)
@staticmethod
async def write_docstring(
filename: str | Path, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"
) -> str:
data = await aread(str(filename))
code = await WriteDocstring().run(data, style=style)
if overwrite:
await awrite(filename, code)
return code
def _simplify_python_code(code: str) -> None:
"""Simplifies the given Python code by removing expressions and the last if statement.
@ -207,13 +220,4 @@ def _simplify_python_code(code: str) -> None:
if __name__ == "__main__":
import fire
async def run(filename: str, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"):
with open(filename) as f:
code = f.read()
code = await WriteDocstring().run(code, style=style)
if overwrite:
with open(filename, "w") as f:
f.write(code)
return code
fire.Fire(run)
fire.Fire(WriteDocstring.write_docstring)

View file

@ -66,7 +66,7 @@ NEW_REQ_TEMPLATE = """
class WritePRD(Action):
name: str = ""
name: str = "WritePRD"
content: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
@ -181,18 +181,13 @@ class WritePRD(Action):
@staticmethod
async def _rename_workspace(prd):
if CONFIG.project_path: # Updating on the old version has already been specified if it's valid. According to
# Section 2.2.3.10 of RFC 135
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
return
if not CONFIG.project_name:
if isinstance(prd, (ActionOutput, ActionNode)):
ws_name = prd.instruct_content.model_dump()["Project Name"]
else:
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
CONFIG.project_name = ws_name
if ws_name:
CONFIG.project_name = ws_name
CONFIG.git_repo.rename_root(CONFIG.project_name)
async def _is_bugfix(self, context) -> bool:

View file

@ -72,6 +72,7 @@ class Config(metaclass=Singleton):
self.inc = False
self.reqa_file = ""
self.max_auto_summarize_code = 0
self.git_reinit = False
self._init_with_config_files_and_env(yaml_file)
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
@ -110,11 +111,7 @@ class Config(metaclass=Singleton):
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
model_name = model_mappings.get(provider)
model_name = self.get_model_name(provider=provider)
if model_name:
logger.info(f"{provider} Model: {model_name}")
if provider:
@ -122,6 +119,14 @@ class Config(metaclass=Singleton):
return provider
raise NotConfiguredException("You should config a LLM configuration first")
def get_model_name(self, provider=None) -> str:
provider = provider or self.get_default_llm_provider_enum()
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
return model_mappings.get(provider, "")
@staticmethod
def _is_valid_llm_key(k: str) -> bool:
return bool(k and k != "YOUR_API_KEY")
@ -142,7 +147,7 @@ class Config(metaclass=Singleton):
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
# self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")

View file

@ -53,6 +53,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
EXAMPLE_PATH = METAGPT_ROOT / "examples"
DATA_PATH = METAGPT_ROOT / "data"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"

View file

@ -1,111 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/28 00:00
@Author : alexanderwu
@File : milvus_store.py
"""
from typing import TypedDict
import numpy as np
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
from metagpt.document_store.base_store import BaseStore
type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR}
def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
"""Assume the structure of columns is str: regular type"""
fields = []
for col, ctype in columns.items():
if ctype == str:
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], max_length=100)
elif ctype == np.ndarray:
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], dim=2)
else:
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], is_primary=(col == primary_col_name))
fields.append(mcol)
schema = CollectionSchema(fields, description=desc)
return schema
class MilvusConnection(TypedDict):
alias: str
host: str
port: str
class MilvusStore(BaseStore):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/create_collection.md
"""
def __init__(self, connection):
connections.connect(**connection)
self.collection = None
def _create_collection(self, name, schema):
collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong")
return collection
def create_collection(self, name, columns):
schema = columns_to_milvus_schema(columns, "idx")
self.collection = self._create_collection(name, schema)
return self.collection
def drop(self, name):
Collection(name).drop()
def load_collection(self):
self.collection.load()
def build_index(self, field="emb"):
self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}})
def search(self, query: list[list[float]], *args, **kwargs):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/search.md
All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search.
Note the above description, is this logic serious? This should take a long time, right?
"""
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
results = self.collection.search(
data=query,
anns_field=kwargs.get("field", "emb"),
param=search_params,
limit=10,
expr=None,
consistency_level="Strong",
)
# FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface
return results
def write(self, name, schema, *args, **kwargs):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/create_collection.md
:param args:
:param kwargs:
:return:
"""
raise NotImplementedError
def add(self, data, *args, **kwargs):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/insert_data.md
import random
data = [
[i for i in range(2000)],
[i for i in range(10000, 12000)],
[[random.random() for _ in range(2)] for _ in range(2000)],
]
:param args:
:param kwargs:
:return:
"""
self.collection.insert(data)

View file

@ -28,7 +28,7 @@ class SkillManager:
:return:
"""
self._skills[skill.name] = skill
self._store.add(skill.desc, {}, skill.name)
self._store.add(skill.desc, {"name": skill.name, "desc": skill.desc}, skill.name)
def del_skill(self, skill_name: str):
"""

View file

@ -55,9 +55,9 @@ class BrainMemory(BaseModel):
return "\n".join(texts)
@staticmethod
async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
redis = Redis(conf=redis_conf)
if not redis.is_valid() or not redis_key:
async def loads(redis_key: str) -> "BrainMemory":
redis = Redis()
if not redis.is_valid or not redis_key:
return BrainMemory()
v = await redis.get(key=redis_key)
logger.debug(f"REDIS GET {redis_key} {v}")
@ -67,11 +67,11 @@ class BrainMemory(BaseModel):
return bm
return BrainMemory()
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
if not self.is_dirty:
return
redis = Redis(conf=redis_conf)
if not redis.is_valid() or not redis_key:
redis = Redis()
if not redis.is_valid or not redis_key:
return False
v = self.model_dump_json()
if self.cacheable:
@ -86,26 +86,27 @@ class BrainMemory(BaseModel):
async def set_history_summary(self, history_summary, redis_key, redis_conf):
if self.historical_summary == history_summary:
if self.is_dirty:
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
await self.dumps(redis_key=redis_key)
self.is_dirty = False
return
self.historical_summary = history_summary
self.history = []
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
await self.dumps(redis_key=redis_key)
self.is_dirty = False
def add_history(self, msg: Message):
if msg.id:
if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
return
self.history.append(msg.model_dump())
self.history.append(msg)
self.last_history_id = str(msg.id)
self.is_dirty = True
def exists(self, text) -> bool:
for m in reversed(self.history):
if m.get("content") == text:
if m.content == text:
return True
return False
@ -163,7 +164,7 @@ class BrainMemory(BaseModel):
msgs.reverse()
self.history = msgs
self.is_dirty = True
await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
await self.dumps(redis_key=CONFIG.REDIS_KEY)
self.is_dirty = False
return BrainMemory.to_metagpt_history_format(self.history)
@ -217,7 +218,7 @@ class BrainMemory(BaseModel):
return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
@staticmethod
async def _metagpt_rewrite(sentence: str):
async def _metagpt_rewrite(sentence: str, **kwargs):
return sentence
@staticmethod

View file

@ -58,7 +58,7 @@ class GeminiLLM(BaseLLM):
genai.configure(api_key=config.gemini_api_key)
def _user_msg(self, msg: str) -> dict[str, str]:
# Not to change BaseGPTAPI default functions but update with Gemini's conversation format.
# Not to change BaseLLM default functions but update with Gemini's conversation format.
# You should follow the format.
return {"role": "user", "parts": [msg]}

View file

@ -69,7 +69,7 @@ class OpenAILLM(BaseLLM):
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
@ -81,8 +81,8 @@ class OpenAILLM(BaseLLM):
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.OPENAI_BASE_URL:
params["base_url"] = self.config.OPENAI_BASE_URL
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
return params

View file

@ -40,10 +40,11 @@ class ProductManager(Role):
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo:
if CONFIG.git_repo and not CONFIG.git_reinit:
self._set_state(1)
else:
self._set_state(0)
CONFIG.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)

View file

@ -6,6 +6,7 @@
"""
import asyncio
import re
from pydantic import BaseModel
@ -107,9 +108,11 @@ class Researcher(Role):
return msg
def write_report(self, topic: str, content: str):
filename = re.sub(r'[\\/:"*?<>|]+', " ", topic)
filename = filename.replace("\n", "")
if not RESEARCH_PATH.exists():
RESEARCH_PATH.mkdir(parents=True)
filepath = RESEARCH_PATH / f"{topic}.md"
filepath = RESEARCH_PATH / f"{filename}.md"
filepath.write_text(content)

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

108
metagpt/strategy/base.py Normal file
View file

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 9:16 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import List
from anytree import Node, RenderTree
from pydantic import BaseModel
class BaseParser(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def propose(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def sample(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def value(self, input: str, **kwargs) -> str:
raise NotImplementedError
class BaseEvaluator(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def status_verify(self, *args, **kwargs):
raise NotImplementedError
class ThoughtNode(Node):
"""A node representing a thought in the thought tree."""
name: str = ""
value: int = 0
id: int = 0
valid_status: bool = True
def update_value(self, value) -> None:
"""Update the value of the thought node."""
self.value = value
def update_valid_status(self, status) -> None:
"""Update the validity status of the thought node."""
self.valid_status = status
class ThoughtTree(RenderTree):
"""A tree structure to represent thoughts."""
@property
def all_nodes(self) -> List[ThoughtNode]:
"""
Get a list of all nodes in the thought tree.
Returns:
List[ThoughtNode]: A list containing all nodes in the thought tree.
"""
all_nodes = [node for _, _, node in self]
return all_nodes
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
"""
Update the tree with new thoughts.
Args:
thought (List[dict]): A list of dictionaries representing thought information.
current_node (ThoughtNode): The current node under which new thoughts will be added.
Returns:
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
"""
nodes = []
for node_info in thought:
node = ThoughtNode(
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
)
nodes.append(node)
return nodes
def parse_node_path(self, node) -> List[str]:
"""
Parse and retrieve the hierarchical path of the given thought node.
This method traverses the parent nodes of the provided 'node' and constructs
the full path from the root node to the given node.
Args:
node: The thought node for which the hierarchical path needs to be parsed.
Returns:
List[str]: A list representing the full hierarchical path of the given thought node.
The list is ordered from the root node to the provided node.
"""
full_node_path = []
while node is not None:
full_node_path.append(node.name)
node = node.parent
full_node_path.reverse()
return full_node_path
def show(self) -> None:
"""Print the updated tree."""
print("\nUpdated Tree:")
for pre, _, node in self:
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/26/2023 3:32 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:06 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from metagpt.strategy.prompt_templates.creative_writing import cot_prompt, vote_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
class TextGenParser(BaseParser):
propose_prompt: str = cot_prompt
value_prompt: str = vote_prompt
def __call__(self, input_text: str) -> str:
return input_text
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
# node_result = self(input)
id = kwargs.get("node_id", "0")
return self.value_prompt + f"Choice {id}:\n{input}\n"
class TextGenEvaluator(BaseEvaluator):
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
value = 0
node_id = kwargs.get("node_id", "0")
pattern = r".*best choice is .*(\d+).*"
match = re.match(pattern, evaluation, re.DOTALL)
if match:
vote = int(match.groups()[0])
print(vote)
if vote == int(node_id):
value = 1
except:
value = 0
return value
def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status
if __name__ == "__main__":
import asyncio
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didnt like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""
parser = TextGenParser()
evaluator = TextGenEvaluator()
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)
tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot_base.solve(init_prompt=initial_prompt))

View file

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:36 AM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from metagpt.strategy.prompt_templates.game24 import propose_prompt, value_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
class Game24Parser(BaseParser):
propose_prompt: str = propose_prompt
value_prompt: str = value_prompt
def __call__(self, input_text: str) -> str:
last_line = input_text.strip().split("\n")[-1]
return last_line.split("left: ")[-1].split(")")[0]
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
node_result = self(input)
return self.value_prompt.format(input=node_result)
class Game24Evaluator(BaseEvaluator):
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
value = self.value_map[matches[0]]
except:
value = 0.001
return value
def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status
if __name__ == "__main__":
import asyncio
initial_prompt = """4 5 6 10"""
parser = Game24Parser()
evaluator = Game24Evaluator()
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)
tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt=initial_prompt))

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 5:21 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,25 @@
standard_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
"""
cot_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Make a plan then write. Your output should be of the following format:
Plan:
Your plan here.
Passage:
Your passage here.
"""
vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
"""
compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
"""
score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
"""

View file

@ -0,0 +1,139 @@
# 5-shot
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
"""
# 5-shot
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
"""
# 1-shot
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts:
Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 / 2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Here is my task for 1 input and {n_generate_sample} possible thoughts:
Input: {input}
Possible next steps:
"""
value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
"""
value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input}
Answer: {answer}
Judge:"""

272
metagpt/strategy/tot.py Normal file
View file

@ -0,0 +1,272 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import asyncio
from typing import Any, List
from pydantic import BaseModel, Field
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser
OUTPUT_FORMAT = """
Output a list of jsons following the format:
```json
[
{
"node_id": str = "unique identifier for a solution, can be an ordinal",
"node_state_instruction": "specified sample of solution",
},
...
]
```
"""
class ThoughtSolverBase(BaseModel):
thought_tree: str = ""
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.llm.use_system_prompt = False
async def solve(self, init_prompt):
"""
Solve method for subclasses to implement.
"""
raise NotImplementedError("Subclasses must implement the solve method")
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
"""
Generate children thoughts based on the current state.
Args:
current_state (str): The current state for which thoughts are generated.
current_node (ThoughtNode): The current node in the thought tree.
Returns:
List[ThoughtNode]: List of nodes representing the generated thoughts.
"""
state_prompt = self.config.parser.propose(
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
)
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
thoughts = CodeParser.parse_code(block=None, text=rsp)
thoughts = eval(thoughts)
# fixme 避免不跟随生成过多nodes
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
return self.thought_tree.update_node(thoughts, current_node=current_node)
async def evaluate_node(self, node, parent_value) -> None:
"""
Evaluate a node and update its status and value.
Args:
node (ThoughtNode): The node to be evaluated.
parent_value (float): The parent node's value.
Returns:
None
"""
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
evaluation = await self.llm.aask(msg=eval_prompt)
value = self.config.evaluator(evaluation, **{"node_id": node.id})
status = self.config.evaluator.status_verify(value)
node.update_valid_status(status=status)
# 累计分数
node.update_value(parent_value + value)
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
"""
Select nodes based on the configured selection method.
Args:
thought_nodes (List[ThoughtNode]): List of nodes to be selected.
Returns:
List[ThoughtNode]: List of selected nodes.
"""
# selection
if self.config.method_select == MethodSelect.SAMPLE:
raise NotImplementedError
elif self.config.method_select == MethodSelect.GREEDY:
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
for node in thought_nodes:
if node not in select_nodes:
node.parent = None # 从树中删除节点
return select_nodes
def update_solution(self):
"""
Select the result with the highest score.
Returns:
- List[ThoughtNode]: List of nodes representing the best solution.
- List[str]: List of node names forming the best solution path.
"""
best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
best_solution_path = self.thought_tree.parse_node_path(best_node)
return [best_node], best_solution_path
class BFSSolver(ThoughtSolverBase):
async def solve(self, init_prompt=""):
"""
Solve the problem using Breadth-First Search (BFS) strategy.
Args:
init_prompt (str): The initial prompt for the solver.
Returns:
List[str]: The best solution path obtained through BFS.
"""
root = ThoughtNode(init_prompt)
self.thought_tree = ThoughtTree(root)
current_nodes = [root]
for step in range(self.config.max_steps):
solutions = await self._bfs_build(current_nodes)
selected_nodes = self.select_nodes(solutions)
current_nodes = selected_nodes
self.thought_tree.show()
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
async def _bfs_build(self, current_nodes):
"""
Build the thought tree using Breadth-First Search (BFS) strategy.
Args:
current_nodes (List[ThoughtNode]): Current nodes to expand.
Returns:
List[ThoughtNode]: The solutions obtained after expanding the current nodes.
"""
tasks = []
for node in current_nodes:
current_state = self.config.parser(node.name)
current_value = node.value
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
thought_nodes_list = await asyncio.gather(*tasks)
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
return solutions
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await asyncio.gather(
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
)
return thought_nodes
class DFSSolver(ThoughtSolverBase):
async def _dfs(self, root_node):
"""
Perform Depth-First Search (DFS) on the thought tree.
Args:
root_node (ThoughtNode): The root node of the thought tree.
Returns:
List[str]: The solution path obtained through DFS.
"""
impossible_state_cnt = 0
node = root_node
for step in range(self.max_steps):
current_state = self.config.parser(node.name)
current_value = node.value
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await self.evaluate_node(thought_nodes[0], parent_value=current_value)
if thought_nodes[0].valid_status is False:
impossible_state_cnt += 1
if impossible_state_cnt >= 2:
logger.info("impossible state reached, break")
break
node = thought_nodes[0]
_solution_path = self.thought_tree.parse_node_path(node)
self.thought_tree.show()
return _solution_path
async def solve(self, init_prompt="", root=ThoughtNode("")):
"""
Solve the problem using Depth-First Search (DFS) strategy.
Args:
init_prompt (str): The initial prompt for the solver.
Returns:
List[str]: The best solution path obtained through DFS.
"""
root = ThoughtNode(init_prompt)
self.thought_tree = ThoughtTree(root)
for n in range(self.config.n_solution_sample):
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
await self._dfs(root)
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
class MCTSSolver(ThoughtSolverBase):
async def solve(self, init_prompt=""):
raise NotImplementedError
class TreeofThought(BaseModel):
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
strategy: Strategy = Field(default=Strategy.BFS)
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._initialize_solver(self.strategy)
def _initialize_solver(self, strategy):
"""
Initialize the solver based on the chosen strategy.
Args:
strategy (Strategy): The strategy to use for solving.
Returns:
ThoughtSolverBase: An instance of the appropriate solver.
"""
if strategy == Strategy.BFS:
self.solver = BFSSolver(config=self.config)
elif strategy == Strategy.DFS:
self.solver = DFSSolver(config=self.config)
elif strategy == Strategy.MCTS:
self.solver = MCTSSolver(config=self.config)
else:
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
async def solve(self, init_prompt=""):
"""
Solve the problem using the specified strategy.
Args:
init_prompt (str): The initial prompt for the solver.
strategy (str): The strategy to use for solving.
Returns:
Any: The solution obtained using the selected strategy.
"""
await self.solver.solve(init_prompt)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 9:14 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from enum import Enum
from pydantic import BaseModel, Field
from metagpt.strategy.base import BaseEvaluator, BaseParser
class MethodSelect(Enum):
SAMPLE = "sample"
GREEDY = "greedy"
class Strategy(Enum):
BFS = "BFS"
DFS = "DFS"
MCTS = "MCTS"
class ThoughtSolverConfig(BaseModel):
max_steps: int = 3
method_select: str = MethodSelect.GREEDY # ["sample"/"greedy"]
n_generate_sample: int = 5 # per node
n_select_sample: int = 3 # per path
n_solution_sample: int = 5 # only for dfs
parser: BaseParser = Field(default_factory=BaseParser)
evaluator: BaseEvaluator = Field(default_factory=BaseEvaluator)

View file

@ -95,4 +95,4 @@ class SearchEngine:
Returns:
The search results as a string or a list of dictionaries.
"""
return await self.run_func(query, max_results=max_results, as_string=as_string)
return await self.run_func(query, max_results, as_string)

View file

@ -43,7 +43,8 @@ class SerpAPIWrapper(BaseModel):
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
return self._process_response(await self.results(query, max_results), as_string=as_string)
result = await self.results(query, max_results)
return self._process_response(result, as_string=as_string)
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""

View file

@ -14,6 +14,8 @@ from typing import Literal
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.core.download_manager import WDMDownloadManager
from webdriver_manager.core.http import WDMHttpClient
from metagpt.config import CONFIG
from metagpt.utils.parse_html import WebPage
@ -93,6 +95,13 @@ _webdriver_manager_types = {
}
class WDMHttpProxyClient(WDMHttpClient):
def get(self, url, **kwargs):
if "proxies" not in kwargs and CONFIG.global_proxy:
kwargs["proxies"] = {"all_proxy": CONFIG.global_proxy}
return super().get(url, **kwargs)
def _gen_get_driver_func(browser_type, *args, executable_path=None):
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
@ -101,7 +110,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
if not executable_path:
module_name, type_name = _webdriver_manager_types[browser_type]
DriverManager = getattr(importlib.import_module(module_name), type_name)
driver_manager = DriverManager()
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
# driver_manager.driver_cache.find_driver(driver_manager.driver))
executable_path = driver_manager.install()

View file

@ -131,13 +131,11 @@ class OutputParser:
try:
content = cls.parse_code(text=content)
except Exception:
pass
# 尝试解析list
try:
content = cls.parse_file_list(text=content)
except Exception:
pass
# 尝试解析list
try:
content = cls.parse_file_list(text=content)
except Exception:
pass
parsed_data[block] = content
return parsed_data

View file

@ -63,5 +63,5 @@ class Redis:
self._client = None
@property
def is_valid(self):
return bool(self._client)
def is_valid(self) -> bool:
return self._client is not None

View file

@ -1,15 +0,0 @@
# For unit test
-r requirements.txt
connexion[uvicorn]~=3.0.5
azure-cognitiveservices-speech~=1.31.0
duckduckgo_search
serpapi
google
httplib2
google_api_python_client
selenium
webdriver_manager
pyppeteer
#aioboto3~=11.3.0 # Used by metagpt/utils/s3.py
aioredis~=2.0.1 # Used by metagpt/utils/redis.py

View file

@ -1,13 +1,13 @@
aiohttp==3.8.4
#azure_storage==0.37.0
channels==4.0.0
# chromadb==0.3.22
# chromadb
# Django==4.1.5
# docx==0.2.4
#faiss==1.5.3
faiss_cpu==1.7.4
fire==0.4.0
typer
typer==0.9.0
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0

View file

@ -22,6 +22,31 @@ here = Path(__file__).resolve().parent
long_description = (here / "README.md").read_text(encoding="utf-8")
requirements = (here / "requirements.txt").read_text(encoding="utf-8").splitlines()
extras_require = {
"playwright": ["playwright>=1.26", "beautifulsoup4"],
"selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
"search-google": ["google-api-python-client==2.94.0"],
"search-ddg": ["duckduckgo-search==3.8.5"],
"ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"],
"test": ["pytest", "pytest-cov", "pytest-asyncio", "pytest-mock"],
}
extras_require["test"] = [
*set(i for j in extras_require.values() for i in j),
"pytest",
"pytest-asyncio",
"pytest-cov",
"pytest-mock",
"pytest-html",
]
extras_require["pyppeteer"] = [
"pyppeteer>=1.0.2"
] # pyppeteer is unmaintained and there are conflicts with dependencies
extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pre-commit~=3.6.0"],)
setup(
name="metagpt",
version="0.5.2",
@ -36,16 +61,7 @@ setup(
packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]),
python_requires=">=3.9",
install_requires=requirements,
extras_require={
"playwright": ["playwright>=1.26", "beautifulsoup4"],
"selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
"search-google": ["google-api-python-client==2.94.0"],
"search-ddg": ["duckduckgo-search==3.8.5"],
"pyppeteer": ["pyppeteer>=1.0.2"],
"ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"],
"dev": ["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pre-commit~=3.6.0"],
"test": ["pytest", "pytest-cov", "pytest-asyncio", "pytest-mock"],
},
extras_require=extras_require,
cmdclass={
"install_mermaid": InstallMermaidCLI,
},

View file

@ -89,6 +89,7 @@ def loguru_caplog(caplog):
@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown_git_repo(request):
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
CONFIG.git_reinit = True
# Destroy git repo at the end of the test session.
def fin():

View file

@ -0,0 +1 @@
{"design_filename": "docs/system_design/20231221155954.json", "task_filename": "docs/tasks/20231221155954.json", "codes_filenames": ["game.py", "main.py"], "reason": "```json\n{\n \"game.py\": \"Add handling for no empty cells in add_new_tile function, Update score in move function\",\n \"main.py\": \"Handle game over condition in the game loop\"\n}\n```"}

View file

@ -0,0 +1 @@
{"Language": "en_us", "Programming Language": "Python", "Original Requirements": "write a 2048 game", "Project Name": "game_2048", "Product Goals": ["Create an addictive and engaging gaming experience", "Ensure smooth performance and responsiveness", "Offer customizable game settings and features"], "User Stories": ["As a player, I want to be able to play the game on different devices and screen sizes", "As a gamer, I want to be challenged with increasing difficulty levels as I progress", "As a user, I want to be able to undo my last move in the game"], "Competitive Analysis": ["2048 Game by Gabriele Cirulli: Popular and addictive, lacks advanced customization options"], "Competitive Quadrant Chart": "quadrantChart\n title \"Engagement and Customization of 2048 Games\"\n x-axis \"Low Customization\" --> \"High Customization\"\n y-axis \"Low Engagement\" --> \"High Engagement\"\n quadrant-1 \"Enhance Customization\"\n quadrant-2 \"Improve Engagement\"\n quadrant-3 \"Maintain Customization, Enhance Engagement\"\n quadrant-4 \"Highly Engaging and Customizable\"\n \"2048 Game by Gabriele Cirulli\": [0.4, 0.7]\n \"Our Target Product\": [0.6, 0.8]", "Requirement Analysis": "The product should provide an intuitive and seamless gaming experience with customizable features to enhance user engagement.", "Requirement Pool": [["P0", "Implement game logic and user interface"], ["P1", "Incorporate multiple difficulty levels and scoring system"], ["P2", "Integrate customizable game settings and undo feature"]], "UI Design draft": "The UI should have a clean and modern design with intuitive game controls and customizable settings for difficulty levels and game themes.", "Anything UNCLEAR": "..."}

View file

@ -0,0 +1 @@
{"Implementation approach": "We will use the Pygame library to create the game interface and handle user input. The game logic will be implemented using Python classes and data structures.", "File list": ["main.py", "game.py"], "Data structures and interfaces": "classDiagram\n class Game {\n -grid: List[List[int]]\n -score: int\n -game_over: bool\n +__init__()\n +reset_game()\n +move(direction: str)\n +is_game_over() bool\n +get_empty_cells() List[Tuple[int, int]]\n +add_new_tile()\n +get_score() int\n }\n class UI {\n -game: Game\n +__init__(game: Game)\n +draw_grid()\n +draw_score()\n +draw_game_over()\n +handle_input()\n }\n Game --> UI", "Program call flow": "sequenceDiagram\n participant M as Main\n participant G as Game\n participant U as UI\n M->>G: reset_game()\n M->>U: draw_grid()\n M->>U: draw_score()\n M->>U: handle_input()\n U->>G: move(direction)\n G->>G: add_new_tile()\n G->>U: draw_grid()\n G->>U: draw_score()\n G->>U: draw_game_over()\n G->>G: is_game_over()\n G->>G: get_empty_cells()\n G->>G: get_score()", "Anything UNCLEAR": "..."}

View file

@ -0,0 +1 @@
{"Required Python packages": ["pygame==2.0.1"], "Required Other language third-party packages": ["No third-party dependencies required"], "Logic Analysis": [["game.py", "Contains Game class and related functions for game logic"], ["main.py", "Contains main function, initializes the game and UI"]], "Task list": ["game.py", "main.py"], "Full API spec": "", "Shared Knowledge": "The game logic will be implemented using Python classes and data structures. The Pygame library will be used to create the game interface and handle user input.", "Anything UNCLEAR": "..."}

View file

@ -0,0 +1 @@
{"summary": "---\n## instruction:\nThe errors are caused by both the development code and the test code. The development code needs to be fixed to ensure that the `reset_game` method resets the grid properly. The test code also needs to be fixed to ensure that the `add_new_tile` test does not raise an index out of range error.\n\n## File To Rewrite:\ngame.py\n\n## Status:\nFAIL\n\n## Send To:\nEngineer\n---", "stdout": "", "stderr": "E.......F\n======================================================================\nERROR: test_add_new_tile (__main__.TestGame)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/Users/xx/tests/test_game.py\", line 104, in test_add_new_tile\n self.assertIn(self.game.grid[empty_cells[0][0]][empty_cells[0][1]], [2, 4])\nIndexError: list index out of range\n\n======================================================================\nFAIL: test_reset_game (__main__.TestGame)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/Users/xx/tests/test_game.py\", line 13, in test_reset_game\n self.assertEqual(self.game.grid, [[0 for _ in range(4)] for _ in range(4)])\nAssertionError: Lists differ: [[0, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 2], [0, 0, 0, 0]] != [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]\n\nFirst differing element 1:\n[0, 2, 0, 0]\n[0, 0, 0, 0]\n\n- [[0, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 2], [0, 0, 0, 0]]\n? --- ^\n\n+ [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]\n? +++ ^\n\n\n----------------------------------------------------------------------\nRan 9 tests in 0.002s\n\nFAILED (failures=1, errors=1)\n"}

View file

@ -12,6 +12,7 @@ import pytest
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.environment import Environment
from metagpt.llm import LLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.team import Team
@ -76,18 +77,24 @@ async def test_action_node_one_layer():
assert "key-a" in markdown_template
assert node_dict["key-a"] == "instruction-b"
assert "key-a" in repr(node)
@pytest.mark.asyncio
async def test_action_node_two_layer():
node_a = ActionNode(key="key-a", expected_type=str, instruction="i-a", example="e-a")
node_b = ActionNode(key="key-b", expected_type=str, instruction="i-b", example="e-b")
node_a = ActionNode(key="reasoning", expected_type=str, instruction="reasoning step by step", example="")
node_b = ActionNode(key="answer", expected_type=str, instruction="the final answer", example="")
root = ActionNode.from_children(key="", nodes=[node_a, node_b])
assert "key-a" in root.children
root = ActionNode.from_children(key="detail answer", nodes=[node_a, node_b])
assert "reasoning" in root.children
assert node_b in root.children.values()
json_template = root.compile(context="123", schema="json", mode="auto")
assert "i-a" in json_template
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
assert "579" in answer1.content
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
assert "579" in answer2.content
t_dict = {
@ -116,11 +123,28 @@ WRITE_TASKS_OUTPUT_MAPPING = {
"Anything UNCLEAR": (str, ...),
}
WRITE_TASKS_OUTPUT_MAPPING_MISSING = {
"Required Python third-party packages": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
output = test_class(**t_dict)
print(output.schema())
assert output.schema()["title"] == "test_class"
assert output.schema()["type"] == "object"
assert output.schema()["properties"]["Full API spec"]
def test_create_model_class_missing():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
assert test_class.__name__ == "test_class"
_ = test_class(**t_dict) # 这里应该要挂掉
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)

View file

@ -1,16 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/1 22:50
@Author : alexanderwu
@File : test_azure_tts.py
"""
from metagpt.tools.azure_tts import AzureTTS
def test_azure_tts():
azure_tts = AzureTTS()
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
# 运行需要先配置 SUBSCRIPTION_KEY
# TODO: 这里如果要检验还要额外加上对应的asr才能确保前后生成是接近一致的但现在还没有

View file

@ -1,101 +0,0 @@
import os
import tempfile
import pytest
from metagpt.actions.clone_function import (
CloneFunction,
run_function_code,
run_function_script,
)
source_code = """
import pandas as pd
import ta
def user_indicator():
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
stock_data.head()
# 计算简单移动平均线
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
stock_data[["Date", "Close", "SMA"]].head()
# 计算布林带
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
)
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert "def " in code
stock_path = "./tests/data/baba_stock.csv"
df, msg = run_function_code(code, "stock_indicator", stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)
def test_run_function_script():
# 创建一个临时文件并写入脚本内容
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
temp_file.write(script_content)
temp_file_path = temp_file.name
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
error_temp_file.write(invalid_script_content)
error_temp_file_path = error_temp_file.name
try:
# 正常情况下运行脚本
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
assert result == 3
# 不存在的脚本路径
with pytest.raises(FileNotFoundError):
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
# 无效的脚本内容
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
assert not result
assert "SyntaxError" in traceback
# 函数调用失败的情况
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
assert not result
assert "KeyError" in traceback
finally:
# 删除临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)

View file

@ -6,27 +6,26 @@
@Author : Stitch-z
@File : test_invoice_ocr.py
"""
import json
import os
from pathlib import Path
import pytest
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
from metagpt.const import TEST_DATA_PATH
@pytest.mark.asyncio
@pytest.mark.parametrize(
"invoice_path",
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
Path("invoices/invoice-3.jpg"),
Path("invoices/invoice-4.zip"),
],
)
async def test_invoice_ocr(invoice_path: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_invoice_ocr(invoice_path: Path):
invoice_path = TEST_DATA_PATH / invoice_path
resp = await InvoiceOCR().run(file_path=Path(invoice_path))
assert isinstance(resp, list)
@ -34,25 +33,29 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
(Path("invoices/invoice-1.pdf"), {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_generate_table(invoice_path: Path, expected_result: dict):
invoice_path = TEST_DATA_PATH / invoice_path
filename = invoice_path.name
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
assert json.dumps(table_data) == json.dumps(expected_result)
assert isinstance(table_data, list)
table_data = table_data[0]
assert expected_result["收款人"] == table_data["收款人"]
assert expected_result["城市"] in table_data["城市"]
assert float(expected_result["总费用/元"]) == float(table_data["总费用/元"])
assert expected_result["开票日期"] == table_data["开票日期"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "query", "expected_result"),
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
)
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
invoice_path = TEST_DATA_PATH / invoice_path
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
assert expected_result in result

View file

@ -0,0 +1,124 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_research.py
"""
import pytest
from metagpt.actions import CollectLinks, research
@pytest.mark.asyncio
async def test_action():
action = CollectLinks()
result = await action.run(topic="baidu")
assert result
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", '
'"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
rank_before = []
rank_after = []
url_per_query = 4
def rank_func(results):
results = results[:url_per_query]
rank_before.append(results)
results = results[::-1]
rank_after.append(results)
return results
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z
@pytest.mark.asyncio
async def test_web_browse_and_summarize(mocker):
async def mock_llm_ask(*args, **kwargs):
return "metagpt"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
url = "https://github.com/geekan/MetaGPT"
url2 = "https://github.com/trending"
query = "What's new in metagpt"
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] == "metagpt"
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
assert len(resp) == 2
async def mock_llm_ask(*args, **kwargs):
return "Not relevant."
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] is None
@pytest.mark.asyncio
async def test_conduct_research(mocker):
data = None
async def mock_llm_ask(*args, **kwargs):
nonlocal data
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
return data
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
content = (
"MetaGPT takes a one line requirement as input and "
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
)
resp = await research.ConductResearch().run("The application of MetaGPT", content)
assert resp == data
async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", ' '"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
return ""
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,13 +14,13 @@ from metagpt.schema import RunCodeContext
@pytest.mark.asyncio
async def test_run_text():
result, errs = await RunCode.run_text("result = 1 + 1")
assert result == 2
assert errs == ""
out, err = await RunCode.run_text("result = 1 + 1")
assert out == 2
assert err == ""
result, errs = await RunCode.run_text("result = 1 / 0")
assert result == ""
assert "ZeroDivisionError" in errs
out, err = await RunCode.run_text("result = 1 / 0")
assert out == ""
assert "division by zero" in err
@pytest.mark.asyncio

View file

@ -58,7 +58,29 @@ class TestSkillAction:
action = SkillAction(skill=self.skill, args=parser_action.args)
rsp = await action.run()
assert rsp
assert "image/png;base64," in rsp.content
assert "image/png;base64," in rsp.content or "http" in rsp.content
@pytest.mark.parametrize(
("skill_name", "txt", "want"),
[
("skill1", 'skill1(a="1", b="2")', {"a": "1", "b": "2"}),
("skill1", '(a="1", b="2")', None),
("skill1", 'skill1(a="1", b="2"', None),
],
)
def test_parse_arguments(self, skill_name, txt, want):
args = ArgumentsParingAction.parse_arguments(skill_name, txt)
assert args == want
@pytest.mark.asyncio
async def test_find_and_call_function_error(self):
with pytest.raises(ValueError):
await SkillAction.find_and_call_function("dummy_call", {"a": 1})
@pytest.mark.asyncio
async def test_skill_action_error(self):
action = SkillAction(skill=self.skill, args={})
await action.run()
if __name__ == "__main__":

View file

@ -0,0 +1,51 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_talk_action.py
"""
import pytest
from metagpt.actions.talk_action import TalkAction
from metagpt.config import CONFIG
from metagpt.schema import Message
@pytest.mark.asyncio
@pytest.mark.parametrize(
("agent_description", "language", "context", "knowledge", "history_summary"),
[
(
"mathematician",
"English",
"How old is Susie?",
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
"balabala... (useless words)",
),
(
"mathematician",
"Chinese",
"Does Susie have an apple?",
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
"Susie had an apple, and she ate it right now",
),
],
)
async def test_prompt(agent_description, language, context, knowledge, history_summary):
# Prerequisites
CONFIG.agent_description = agent_description
CONFIG.language = language
action = TalkAction(context=context, knowledge=knowledge, history_summary=history_summary)
assert "{" not in action.prompt
assert "{" not in action.prompt_gpt4
rsp = await action.run()
assert rsp
assert isinstance(rsp, Message)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,12 +6,24 @@
@File : test_write_code.py
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
"""
from pathlib import Path
import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.config import CONFIG
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM as LLM
from metagpt.schema import CodingContext, Document
from metagpt.utils.common import aread
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
@ -37,3 +49,47 @@ async def test_write_code_directly():
llm = LLM()
rsp = await llm.aask(prompt)
logger.info(rsp)
@pytest.mark.asyncio
async def test_write_code_deps():
# Prerequisites
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
demo_path = Path(__file__).parent / "../../data/demo_project"
await FileRepository.save_file(
filename="test_game.py.json",
content=await aread(str(demo_path / "test_game.py.json")),
relative_path=TEST_OUTPUTS_FILE_REPO,
)
await FileRepository.save_file(
filename="20231221155954.json",
content=await aread(str(demo_path / "code_summaries.json")),
relative_path=CODE_SUMMARIES_FILE_REPO,
)
await FileRepository.save_file(
filename="20231221155954.json",
content=await aread(str(demo_path / "system_design.json")),
relative_path=SYSTEM_DESIGN_FILE_REPO,
)
await FileRepository.save_file(
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
)
await FileRepository.save_file(
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace
)
context = CodingContext(
filename="game.py",
design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO),
task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
code_doc=Document(filename="game.py", content="", root_path="snake1"),
)
coding_doc = Document(root_path="snake1", filename="game.py", content=context.json())
action = WriteCode(context=coding_doc)
rsp = await action.run()
assert rsp
assert rsp.code_doc.content
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -30,3 +30,13 @@ class Person:
async def test_write_docstring(style: str, part: str):
ret = await WriteDocstring().run(code, style=style)
assert part in ret
@pytest.mark.asyncio
async def test_write():
code = await WriteDocstring.write_docstring(__file__)
assert code
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -23,10 +23,14 @@ async def test_write_prd_review():
Timeline: The feature should be ready for testing in 1.5 months.
"""
write_prd_review = WritePRDReview("write_prd_review")
write_prd_review = WritePRDReview(name="write_prd_review")
prd_review = await write_prd_review.run(prd)
# We cannot exactly predict the generated PRD review, but we can check if it is a string and if it is not empty
assert isinstance(prd_review, str)
assert len(prd_review) > 0
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,53 +6,21 @@
@File : test_write_teaching_plan.py
"""
import asyncio
from typing import Optional
from langchain.llms.base import LLM
from pydantic import BaseModel
import pytest
from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart
from metagpt.config import Config
from metagpt.schema import Message
class MockWriteTeachingPlanPart(WriteTeachingPlanPart):
def __init__(self, options, name: str = "", context=None, llm: LLM = None, topic="", language="Chinese"):
super().__init__(options, name, context, llm, topic, language)
async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
return f"{WriteTeachingPlanPart.DATA_BEGIN_TAG}\nprompt\n{WriteTeachingPlanPart.DATA_END_TAG}"
async def mock_write_teaching_plan_part():
class Inputs(BaseModel):
input: str
name: str
topic: str
language: str
inputs = [
{"input": "AABBCC", "name": "A", "topic": WriteTeachingPlanPart.COURSE_TITLE, "language": "C"},
{"input": "DDEEFFF", "name": "A1", "topic": "B1", "language": "C1"},
]
for i in inputs:
seed = Inputs(**i)
options = Config().runtime_options
act = MockWriteTeachingPlanPart(options=options, name=seed.name, topic=seed.topic, language=seed.language)
await act.run([Message(content="")])
assert act.topic == seed.topic
assert str(act) == seed.topic
assert act.name == seed.name
assert act.rsp == "# prompt" if seed.topic == WriteTeachingPlanPart.COURSE_TITLE else "prompt"
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_write_teaching_plan_part())
loop.run_until_complete(task)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("topic", "context"),
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
)
async def test_write_teaching_plan_part(topic, context):
action = WriteTeachingPlanPart(topic=topic, context=context)
rsp = await action.run()
assert rsp
if __name__ == "__main__":
test_suite()
pytest.main([__file__, "-s"])

View file

@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/6/11 21:08
@Author : alexanderwu
@File : test_milvus_store.py
"""
import random
import numpy as np
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
from metagpt.logs import logger
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
book_data = [
[i for i in range(10)],
[f"book-{i}" for i in range(10)],
[f"book-desc-{i}" for i in range(10000, 10010)],
[[random.random() for _ in range(2)] for _ in range(10)],
[random.random() for _ in range(10)],
]
def test_milvus_store():
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
milvus_store = MilvusStore(milvus_connection)
milvus_store.drop("Book")
milvus_store.create_collection("Book", book_columns)
milvus_store.add(book_data)
milvus_store.build_index("emb")
milvus_store.load_collection()
results = milvus_store.search([[1.0, 1.0]], field="emb")
logger.info(results)
assert results

View file

@ -29,7 +29,7 @@ points = [
]
def test_milvus_store():
def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
qdrant_store = QdrantStore(qdrant_connection)
@ -43,13 +43,13 @@ def test_milvus_store():
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
results = qdrant_store.search(

View file

@ -7,35 +7,26 @@
@Desc : Unit tests.
"""
import base64
import pytest
from pydantic import BaseModel
from metagpt.config import CONFIG
from metagpt.learn.text_to_image import text_to_image
@pytest.mark.asyncio
async def test():
class Input(BaseModel):
input: str
size_type: str
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY
inputs = [{"input": "Panda emoji", "size_type": "512x512"}]
for i in inputs:
seed = Input(**i)
base64_data = await text_to_image(seed.input)
assert base64_data != ""
print(f"{seed.input} -> {base64_data}")
flags = ";base64,"
assert flags in base64_data
ix = base64_data.find(flags) + len(flags)
declaration = base64_data[0:ix]
assert declaration
data = base64_data[ix:]
assert data
assert base64.b64decode(data, validate=True)
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
if __name__ == "__main__":

View file

@ -6,40 +6,33 @@
@File : test_text_to_speech.py
@Desc : Unit tests.
"""
import asyncio
import base64
from pydantic import BaseModel
import pytest
from metagpt.config import CONFIG
from metagpt.learn.text_to_speech import text_to_speech
async def mock_text_to_speech():
class Input(BaseModel):
input: str
@pytest.mark.asyncio
async def test_text_to_speech():
# Prerequisites
assert CONFIG.IFLYTEK_APP_ID
assert CONFIG.IFLYTEK_API_KEY
assert CONFIG.IFLYTEK_API_SECRET
assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
assert CONFIG.AZURE_TTS_REGION
inputs = [{"input": "Panda emoji"}]
# test azure
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
for i in inputs:
seed = Input(**i)
base64_data = await text_to_speech(seed.input)
assert base64_data != ""
print(f"{seed.input} -> {base64_data}")
flags = ";base64,"
assert flags in base64_data
ix = base64_data.find(flags) + len(flags)
declaration = base64_data[0:ix]
assert declaration
data = base64_data[ix:]
assert data
assert base64.b64decode(data, validate=True)
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_text_to_speech())
loop.run_until_complete(task)
# test iflytek
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
if __name__ == "__main__":
test_suite()
pytest.main([__file__, "-s"])

View file

@ -14,9 +14,9 @@ def test_skill_manager():
manager = SkillManager()
logger.info(manager._store)
write_prd = WritePRD("WritePRD")
write_prd = WritePRD(name="WritePRD")
write_prd.desc = "基于老板或其他人的需求进行PRD的撰写包括用户故事、需求分解等"
write_test = WriteTest("WriteTest")
write_test = WriteTest(name="WriteTest")
write_test.desc = "进行测试用例的撰写"
manager.add_skill(write_prd)
manager.add_skill(write_test)
@ -24,7 +24,7 @@ def test_skill_manager():
skill = manager.get_skill("WriteTest")
logger.info(skill)
rsp = manager.retrieve_skill("PRD")
rsp = manager.retrieve_skill("WritePRD")
logger.info(rsp)
assert rsp[0] == "WritePRD"

View file

@ -5,47 +5,64 @@
@Author : mashenquan
@File : test_brain_memory.py
"""
# import json
# from typing import List
#
# import pydantic
#
# from metagpt.memory.brain_memory import BrainMemory
# from metagpt.schema import Message
#
#
# def test_json():
# class Input(pydantic.BaseModel):
# history: List[str]
# solution: List[str]
# knowledge: List[str]
# stack: List[str]
#
# inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}]
#
# for i in inputs:
# v = Input(**i)
# bm = BrainMemory()
# for h in v.history:
# msg = Message(content=h)
# bm.history.append(msg.model_dump())
# for h in v.solution:
# msg = Message(content=h)
# bm.solution.append(msg.model_dump())
# for h in v.knowledge:
# msg = Message(content=h)
# bm.knowledge.append(msg.model_dump())
# for h in v.stack:
# msg = Message(content=h)
# bm.stack.append(msg.model_dump())
# s = bm.json()
# m = json.loads(s)
# bm = BrainMemory(**m)
# assert bm
# for v in bm.history:
# msg = Message(**v)
# assert msg
#
#
# if __name__ == "__main__":
# test_json()
import pytest
from metagpt.config import LLMProviderEnum
from metagpt.llm import LLM
from metagpt.memory.brain_memory import BrainMemory
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_memory():
memory = BrainMemory()
memory.add_talk(Message(content="talk"))
assert memory.history[0].role == "user"
memory.add_answer(Message(content="answer"))
assert memory.history[1].role == "assistant"
redis_key = BrainMemory.to_redis_key("none", "user_id", "chat_id")
await memory.dumps(redis_key=redis_key)
assert memory.exists("talk")
assert 1 == memory.to_int("1", 0)
memory.last_talk = "AAA"
assert memory.pop_last_talk() == "AAA"
assert memory.last_talk is None
assert memory.is_history_available
assert memory.history_text
memory = await BrainMemory.loads(redis_key=redis_key)
assert memory
@pytest.mark.parametrize(
("input", "tag", "val"),
[("[TALK]:Hello", "TALK", "Hello"), ("Hello", None, "Hello"), ("[TALK]Hello", None, "[TALK]Hello")],
)
def test_extract_info(input, tag, val):
t, v = BrainMemory.extract_info(input)
assert tag == t
assert val == v
@pytest.mark.asyncio
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
async def test_memory_llm(llm):
memory = BrainMemory()
for i in range(500):
memory.add_talk(Message(content="Lily is a girl.\n"))
res = await memory.is_related("apple", "moon", llm)
assert not res
res = await memory.rewrite(sentence="apple Lily eating", context="", llm=llm)
assert "Lily" in res
res = await memory.get_title(llm=llm)
assert res
assert "Lily" in res
assert memory.history or memory.historical_summary
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -3,7 +3,7 @@
"""
@Time : 2023/5/7 17:40
@Author : alexanderwu
@File : test_base_gpt_api.py
@File : test_base_llm.py
"""
import pytest
@ -27,7 +27,7 @@ prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
class MockBaseGPTAPI(BaseLLM):
class MockBaseLLM(BaseLLM):
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
@ -41,12 +41,12 @@ class MockBaseGPTAPI(BaseLLM):
return default_chat_resp
def test_base_gpt_api():
def test_base_llm():
message = Message(role="user", content="hello")
assert "role" in message.to_dict()
assert "user" in str(message)
base_gpt_api = MockBaseGPTAPI()
base_llm = MockBaseLLM()
openai_funccall_resp = {
"choices": [
@ -70,37 +70,37 @@ def test_base_gpt_api():
}
]
}
func: dict = base_gpt_api.get_choice_function(openai_funccall_resp)
func: dict = base_llm.get_choice_function(openai_funccall_resp)
assert func == {
"name": "execute",
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
}
func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp)
func_args: dict = base_llm.get_choice_function_arguments(openai_funccall_resp)
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_gpt_api.ask(prompt_msg)
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_gpt_api.ask_batch([prompt_msg])
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_gpt_api.ask_code([prompt_msg])
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
@pytest.mark.asyncio
async def test_async_base_gpt_api():
base_gpt_api = MockBaseGPTAPI()
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_gpt_api.aask(prompt_msg)
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_gpt_api.aask_batch([prompt_msg])
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_gpt_api.aask_code([prompt_msg])
resp = await base_llm.aask_code([prompt_msg])
assert resp == resp_content

View file

@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_metagpt_api.py
"""
from metagpt.config import LLMProviderEnum
from metagpt.llm import LLM
def test_llm():
llm = LLM(provider=LLMProviderEnum.METAGPT)
assert llm

View file

@ -16,6 +16,7 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
async def test_architect():
# FIXME: make git as env? Or should we support
role = Architect()
role.put_message(MockMessages.req)
rsp = await role.run(MockMessages.prd)

View file

@ -36,7 +36,7 @@ async def test_run():
{
"content": "who is tulin",
"role": "user",
"id": 1,
"id": "1",
},
{"content": "The one who eaten a poison apple.", "role": "assistant"},
],
@ -53,7 +53,7 @@ async def test_run():
{
"content": "can you draw me an picture?",
"role": "user",
"id": 1,
"id": "1",
},
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
],

View file

@ -12,6 +12,7 @@ from pathlib import Path
import pandas as pd
import pytest
from metagpt.const import DATA_PATH, TEST_DATA_PATH
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
from metagpt.schema import Message
@ -22,29 +23,29 @@ from metagpt.schema import Message
[
(
"Invoicing date",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
Path("invoices/invoice-1.pdf"),
Path("invoice_table/invoice-1.xlsx"),
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
Path("invoices/invoice-2.png"),
Path("invoice_table/invoice-2.xlsx"),
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
Path("invoices/invoice-3.jpg"),
Path("invoice_table/invoice-3.xlsx"),
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
),
],
)
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
invoice_path = Path.cwd() / invoice_path
invoice_path = TEST_DATA_PATH / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
invoice_table_path = Path.cwd() / invoice_table_path
invoice_table_path = DATA_PATH / invoice_table_path
df = pd.read_excel(invoice_table_path)
resp = df.to_dict(orient="records")
assert isinstance(resp, list)
@ -52,5 +53,5 @@ async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_tab
resp = resp[0]
assert expected_result["收款人"] == resp["收款人"]
assert expected_result["城市"] in resp["城市"]
assert int(expected_result["总费用/元"]) == int(resp["总费用/元"])
assert float(expected_result["总费用/元"]) == float(resp["总费用/元"])
assert expected_result["开票日期"] == resp["开票日期"]

View file

@ -15,7 +15,7 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
async def test_product_manager():
product_manager = ProductManager()
rsp = await product_manager.handle(MockMessages.req)
rsp = await product_manager.run(MockMessages.req)
logger.info(rsp)
assert len(rsp.content) > 0
assert "Product Goals" in rsp.content

View file

@ -15,5 +15,5 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
async def test_project_manager():
project_manager = ProjectManager()
rsp = await project_manager.handle(MockMessages.system_design)
rsp = await project_manager.run(MockMessages.system_design)
logger.info(rsp)

View file

@ -28,7 +28,23 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
async def test_researcher(mocker):
with TemporaryDirectory() as dirname:
topic = "dataiku vs. datarobot"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
def test_write_report(mocker):
with TemporaryDirectory() as dirname:
for i, topic in enumerate(
[
("1./metagpt"),
('2.:"metagpt'),
("3.*?<>|metagpt"),
("4. metagpt\n"),
]
):
researcher.RESEARCH_PATH = Path(dirname)
content = "# Research Report"
researcher.Researcher().write_report(topic, content)
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")

View file

@ -5,7 +5,6 @@
@Author : Stitch-z
@File : test_tutorial_assistant.py
"""
import shutil
import aiofiles
import pytest
@ -17,8 +16,6 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
async def test_tutorial_assistant(language: str, topic: str):
shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True)
role = TutorialAssistant(language=language)
msg = await role.run(topic)
assert TUTORIAL_PATH.exists()

View file

@ -9,23 +9,25 @@ import pytest
from typer.testing import CliRunner
from metagpt.logs import logger
from metagpt.startup import app
from metagpt.team import Team
runner = CliRunner()
@pytest.mark.asyncio
async def test_team():
async def test_empty_team():
# FIXME: we're now using "metagpt" cli, so the entrance should be replaced instead.
company = Team()
company.run_project("做一个基础搜索引擎,可以支持知识库")
history = await company.run(n_round=5)
history = await company.run(idea="Build a simple search system. I will upload my files later.")
logger.info(history)
# def test_startup():
# args = ["Make a 2048 game"]
# result = runner.invoke(app, args)
def test_startup():
args = ["Make a cli snake game"]
result = runner.invoke(app, args)
logger.info(result)
logger.info(result.output)
if __name__ == "__main__":

View file

@ -10,4 +10,4 @@ def test_team():
company = Team()
company.hire([ProjectManager()])
assert len(company.environment.roles) == 1
assert len(company.env.roles) == 1

View file

@ -7,6 +7,8 @@
"""
from __future__ import annotations
from typing import Callable
import pytest
from metagpt.config import CONFIG
@ -25,7 +27,7 @@ class MockSearchEnine:
@pytest.mark.asyncio
@pytest.mark.parametrize(
("search_engine_typpe", "run_func", "max_results", "as_string"),
("search_engine_type", "run_func", "max_results", "as_string"),
[
(SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
(SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
@ -39,23 +41,18 @@ class MockSearchEnine:
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(
search_engine_typpe,
run_func,
max_results,
as_string,
):
async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool):
# Prerequisites
if search_engine_typpe is SearchEngineType.SERPAPI_GOOGLE:
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert CONFIG.SERPAPI_API_KEY and CONFIG.SERPAPI_API_KEY != "YOUR_API_KEY"
elif search_engine_typpe is SearchEngineType.DIRECT_GOOGLE:
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert CONFIG.GOOGLE_API_KEY and CONFIG.GOOGLE_API_KEY != "YOUR_API_KEY"
assert CONFIG.GOOGLE_CSE_ID and CONFIG.GOOGLE_CSE_ID != "YOUR_CSE_ID"
elif search_engine_typpe is SearchEngineType.SERPER_GOOGLE:
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert CONFIG.SERPER_API_KEY and CONFIG.SERPER_API_KEY != "YOUR_API_KEY"
search_engine = SearchEngine(search_engine_typpe, run_func)
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
search_engine = SearchEngine(search_engine_type, run_func)
rsp = await search_engine.run("metagpt", max_results, as_string)
logger.info(rsp)
if as_string:
assert isinstance(rsp, str)

View file

@ -111,27 +111,27 @@ class TestCodeParser:
def test_parse_blocks(self, parser, text):
result = parser.parse_blocks(text)
print(result)
assert result == {"title": "content", "title2": "content2"}
assert "game.py" in result["Task list"]
def test_parse_block(self, parser, text):
result = parser.parse_block("title", text)
result = parser.parse_block("Task list", text)
print(result)
assert result == "content"
assert "game.py" in result
def test_parse_code(self, parser, text):
result = parser.parse_code("title", text, "python")
result = parser.parse_code("Task list", text, "python")
print(result)
assert result == "print('hello world')"
assert "game.py" in result
def test_parse_str(self, parser, text):
result = parser.parse_str("title", text, "python")
result = parser.parse_str("Anything UNCLEAR", text, "python")
print(result)
assert result == "hello world"
assert "We need clarification on how the high score " in result
def test_parse_file_list(self, parser, text):
result = parser.parse_file_list("Task list", text)
print(result)
assert result == ["task1", "task2"]
assert "game.py" in result
if __name__ == "__main__":

View file

@ -47,7 +47,8 @@ class TestGetProjectRoot:
def test_get_project_root(self):
project_root = get_metagpt_root()
assert project_root.name == "MetaGPT"
src_path = project_root / "metagpt"
assert src_path.exists()
def test_get_root_exception(self):
self.change_etc_dir()

View file

@ -21,10 +21,11 @@ def test_config_class_get_key_exception():
def test_config_yaml_file_not_exists():
config = Config("wtf.yaml")
with pytest.raises(Exception) as exc_info:
config.get("OPENAI_BASE_URL")
assert str(exc_info.value) == "Set OPENAI_API_KEY or Anthropic_API_KEY first"
# FIXME: 由于这里是单例所以会导致Config重新创建失效。后续要将Config改为非单例模式。
_ = Config("wtf.yaml")
# with pytest.raises(Exception) as exc_info:
# config.get("OPENAI_BASE_URL")
# assert str(exc_info.value) == "Set OPENAI_API_KEY or Anthropic_API_KEY first"
def test_options():

View file

@ -10,29 +10,31 @@ import pytest
from metagpt.config import CONFIG
from metagpt.utils.common import check_cmd_exists
from metagpt.utils.mermaid import MMC1, MMC2, mermaid_to_file
from metagpt.utils.mermaid import MMC1, mermaid_to_file
@pytest.mark.asyncio
@pytest.mark.parametrize("engine", ["nodejs", "playwright", "pyppeteer", "ink"])
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
async def test_mermaid(engine):
# Prerequisites
# npm install -g @mermaid-js/mermaid-cli
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
# ink prerequisites: connected to internet
# playwright prerequisites: playwright install --with-deps chromium
assert check_cmd_exists("npm") == 0
assert CONFIG.PYPPETEER_EXECUTABLE_PATH
CONFIG.mermaid_engine = engine
save_to = CONFIG.git_repo.workdir / f"{CONFIG.mermaid_engine}/1"
await mermaid_to_file(MMC1, save_to)
for ext in [".pdf", ".svg", ".png"]:
assert save_to.with_suffix(ext).exists()
save_to.with_suffix(ext).unlink(missing_ok=True)
save_to = CONFIG.git_repo.workdir / f"{CONFIG.mermaid_engine}/2"
await mermaid_to_file(MMC2, save_to)
for ext in [".pdf", ".svg", ".png"]:
assert save_to.with_suffix(ext).exists()
save_to.with_suffix(ext).unlink(missing_ok=True)
# ink does not support pdf
if engine == "ink":
for ext in [".svg", ".png"]:
assert save_to.with_suffix(ext).exists()
save_to.with_suffix(ext).unlink(missing_ok=True)
else:
for ext in [".pdf", ".svg", ".png"]:
assert save_to.with_suffix(ext).exists()
save_to.with_suffix(ext).unlink(missing_ok=True)
if __name__ == "__main__":

View file

@ -54,13 +54,13 @@ def test_parse_file_list():
expected_result = ["file1", "file2", "file3"]
assert OutputParser.parse_file_list(test_text) == expected_result
with pytest.raises(Exception):
OutputParser.parse_file_list("wrong_input")
# with pytest.raises(Exception):
# OutputParser.parse_file_list("wrong_input")
def test_parse_data():
test_data = "##block1\n```python\nprint('Hello, world!')\n```\n##block2\nfiles=['file1', 'file2', 'file3']"
expected_result = {"block1": "print('Hello, world!')", "block2": ["file1", "file2", "file3"]}
expected_result = {"block1": "print('Hello, world!')\n", "block2": ["file1", "file2", "file3"]}
assert OutputParser.parse_data(test_data) == expected_result
@ -94,7 +94,7 @@ def test_parse_data():
(
"""xxx xx""",
list,
None,
[],
[],
),
(

View file

@ -45,9 +45,11 @@ async def test_s3():
@pytest.mark.asyncio
async def test_s3_no_error():
conn = S3()
key = conn.auth_config["aws_secret_access_key"]
conn.auth_config["aws_secret_access_key"] = ""
res = await conn.cache("ABC", ".bak", "script")
assert not res
conn.auth_config["aws_secret_access_key"] = key
if __name__ == "__main__":