From 8bf7d3186a003052fae6c71c84871cb6dccf8e8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 27 Dec 2023 22:46:39 +0800 Subject: [PATCH 1/3] feat: Action Node + exclude parameter refactor: awrite --- metagpt/actions/action_node.py | 53 +++++++++++-------- metagpt/actions/write_prd.py | 6 +-- metagpt/actions/write_prd_an.py | 4 +- metagpt/tools/ut_writer.py | 25 ++------- metagpt/utils/common.py | 8 +++ tests/metagpt/learn/test_text_to_embedding.py | 4 +- tests/metagpt/utils/test_common.py | 11 ++++ 7 files changed, 58 insertions(+), 53 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index b554f15dd..9534e91c5 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -117,19 +117,20 @@ class ActionNode: obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: + def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" - return {k: (v.expected_type, ...) for k, v in self.children.items()} + exclude = exclude or [] + return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude} def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: + def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): - return self.get_children_mapping() - return self.get_self_mapping() + return self.get_children_mapping(exclude=exclude) + return {} if exclude and self.key in exclude else self.get_self_mapping() @classmethod def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): @@ -154,13 +155,13 @@ class ActionNode: new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - def create_children_class(self): + def create_children_class(self, exclude=None): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" - mapping = self.get_children_mapping() + mapping = self.get_children_mapping(exclude=exclude) return self.create_model_class(class_name, mapping) - def to_dict(self, format_func=None, mode="auto") -> Dict: + def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: """将当前节点与子节点都按照node: format的格式组织成字典""" # 如果没有提供格式化函数,使用默认的格式化方式 @@ -180,7 +181,10 @@ class ActionNode: return node_dict # 遍历子节点并递归调用 to_dict 方法 + exclude = exclude or [] for _, child_node in self.children.items(): + if child_node.key in exclude: + continue node_dict.update(child_node.to_dict(format_func)) return node_dict @@ -201,25 +205,25 @@ class ActionNode: else: # markdown return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str: - nodes = self.to_dict(format_func=format_func, mode=mode) + def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: + nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) text = self.compile_to(nodes, schema, kv_sep) return self.tagging(text, schema, tag) - def compile_instruction(self, schema="markdown", mode="children", tag="") -> str: + def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(schema, mode, tag, format_func, kv_sep=": ") + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) - def compile_example(self, schema="json", mode="children", tag="") -> str: + def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(schema, mode, tag, format_func, kv_sep="\n") + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) - def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: """ mode: all/root/children mode="children": 编译所有子节点为一个统一模板,包括instruction与example @@ -235,8 +239,8 @@ class ActionNode: # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", # compile example暂时不支持markdown - instruction = self.compile_instruction(schema="markdown", mode=mode) - example = self.compile_example(schema=schema, tag=TAG, mode=mode) + instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) + example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) # nodes = ", ".join(self.to_dict(mode=mode).keys()) constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] constraint = "\n".join(constraints) @@ -291,11 +295,11 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=CONFIG.timeout): - prompt = self.compile(context=self.context, schema=schema, mode=mode) + async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): + prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": - mapping = self.get_mapping(mode) + mapping = self.get_mapping(mode, exclude=exclude) class_name = f"{self.key}_AN" content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) self.content = content @@ -306,7 +310,7 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): """Fill the node(s) with mode. :param context: Everything we should know when filling node. @@ -323,6 +327,7 @@ class ActionNode: - simple: run only once - complex: run each node :param timeout: Timeout for llm invocation. + :param exclude: The keys of ActionNode to exclude. :return: self """ self.set_llm(llm) @@ -331,12 +336,14 @@ class ActionNode: schema = self.schema if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout) + return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout) + if exclude and i.key in exclude: + continue + child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 289354a11..de647f167 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -23,10 +23,10 @@ from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.fix_bug import FixBug from metagpt.actions.write_prd_an import ( + PROJECT_NAME, WP_IS_RELATIVE_NODE, WP_ISSUE_TYPE_NODE, WRITE_PRD_NODE, - WRITE_PRD_NODE_NO_NAME, ) from metagpt.config import CONFIG from metagpt.const import ( @@ -124,8 +124,8 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - write_prd_node = WRITE_PRD_NODE if not project_name else WRITE_PRD_NODE_NO_NAME - node = await write_prd_node.fill(context=context, llm=self.llm) # schema=schema + exclude = [PROJECT_NAME.key] if project_name else [] + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema await self._rename_workspace(node) return node diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index e33da2451..948d7d62f 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -141,6 +141,7 @@ NODES = [ LANGUAGE, PROGRAMMING_LANGUAGE, ORIGINAL_REQUIREMENTS, + PROJECT_NAME, PRODUCT_GOALS, USER_STORIES, COMPETITIVE_ANALYSIS, @@ -151,8 +152,7 @@ NODES = [ ANYTHING_UNCLEAR, ] -WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES + [PROJECT_NAME]) -WRITE_PRD_NODE_NO_NAME = ActionNode.from_children("WritePRD", NODES) +WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES) WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON]) WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON]) diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 41b2acbd5..f2f2bf51c 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,9 +4,8 @@ import json from pathlib import Path -import aiofiles - from metagpt.provider.openai_api import OpenAILLM as GPTAPI +from metagpt.utils.common import awrite ICL_SAMPLE = """Interface definition: ```text @@ -255,20 +254,14 @@ class UTGenerator: return doc - async def _store(self, data, base, folder, fname): - """Store data in a file.""" - file_path = self.get_file_path(Path(base) / folder, fname) - async with aiofiles.open(file_path, mode="w", encoding="utf-8") as file: - await file.write(data) - async def ask_gpt_and_save(self, question: str, tag: str, fname: str): """Generate questions and store both questions and answers""" messages = [self.icl_sample, question] result = await self.gpt_msgs_to_code(messages=messages) - await self._store(question, self.questions_path, tag, f"{fname}.txt") + await awrite(Path(self.questions_path) / tag / f"{fname}.txt", question) data = result.get("code", "") if result else "" - await self._store(data, self.ut_py_path, tag, f"{fname}.py") + await awrite(Path(self.ut_py_path) / tag / f"{fname}.py", data) async def _generate_ut(self, tag, paths): """Process the structure under a data path @@ -291,15 +284,3 @@ class UTGenerator: result = await GPTAPI().aask_code(messages=messages) return result - - def get_file_path(self, base: Path, fname: str): - """Save different file paths - - Args: - base (str): Path - fname (str): File name - """ - path = Path(base) - path.mkdir(parents=True, exist_ok=True) - file_path = path / fname - return str(file_path) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index ced17bb7f..f03de1da1 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -537,6 +537,14 @@ async def aread(file_path: str) -> str: return content +async def awrite(filename: str | Path, data: str): + """Write file asynchronously.""" + pathname = Path(filename) + pathname.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer: + await writer.write(data) + + async def read_file_block(filename: str | Path, lineno: int, end_lineno: int): if not Path(filename).exists(): return "" diff --git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py index e3d20a759..f9ad20ee7 100644 --- a/tests/metagpt/learn/test_text_to_embedding.py +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -12,7 +12,6 @@ import asyncio from pydantic import BaseModel from metagpt.learn.text_to_embedding import text_to_embedding -from metagpt.tools.openai_text_to_embedding import ResultEmbedding async def mock_text_to_embedding(): @@ -23,8 +22,7 @@ async def mock_text_to_embedding(): for i in inputs: seed = Input(**i) - data = await text_to_embedding(seed.input) - v = ResultEmbedding(**data) + v = await text_to_embedding(seed.input) assert len(v.data) > 0 diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 5e49023a0..53708527f 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -9,6 +9,7 @@ import importlib import os import platform +import uuid from pathlib import Path from typing import Any, Set @@ -25,6 +26,8 @@ from metagpt.utils.common import ( OutputParser, any_to_str, any_to_str_set, + aread, + awrite, check_cmd_exists, concat_namespace, import_class_inst, @@ -170,6 +173,14 @@ class TestGetProjectRoot: async def test_read_file_block(self): assert await read_file_block(filename=__file__, lineno=6, end_lineno=6) == "@File : test_common.py\n" + @pytest.mark.asyncio + async def test_read_write(self): + pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp" + await awrite(pathname, "ABC") + data = await aread(pathname) + assert data == "ABC" + pathname.unlink(missing_ok=True) + if __name__ == "__main__": pytest.main([__file__, "-s"]) From 16f0a0fd06a49c5006a718beacc37358c2573a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 27 Dec 2023 22:46:39 +0800 Subject: [PATCH 2/3] feat: Action Node + exclude parameter refactor: awrite feat: +unit test --- metagpt/actions/action_node.py | 53 +++++++++++-------- metagpt/actions/prepare_documents.py | 3 +- metagpt/actions/research.py | 3 +- metagpt/actions/write_prd.py | 6 +-- metagpt/actions/write_prd_an.py | 4 +- metagpt/config.py | 14 +++-- metagpt/tools/search_engine_serpapi.py | 3 +- metagpt/tools/ut_writer.py | 25 ++------- metagpt/utils/common.py | 8 +++ tests/metagpt/actions/test_azure_tts.py | 16 ------ tests/metagpt/actions/test_research.py | 22 ++++++++ tests/metagpt/actions/test_talk_action.py | 51 ++++++++++++++++++ tests/metagpt/learn/test_text_to_embedding.py | 4 +- tests/metagpt/utils/test_common.py | 11 ++++ 14 files changed, 145 insertions(+), 78 deletions(-) delete mode 100644 tests/metagpt/actions/test_azure_tts.py create mode 100644 tests/metagpt/actions/test_research.py create mode 100644 tests/metagpt/actions/test_talk_action.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index b554f15dd..9534e91c5 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -117,19 +117,20 @@ class ActionNode: obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: + def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" - return {k: (v.expected_type, ...) for k, v in self.children.items()} + exclude = exclude or [] + return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude} def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: + def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): - return self.get_children_mapping() - return self.get_self_mapping() + return self.get_children_mapping(exclude=exclude) + return {} if exclude and self.key in exclude else self.get_self_mapping() @classmethod def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): @@ -154,13 +155,13 @@ class ActionNode: new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - def create_children_class(self): + def create_children_class(self, exclude=None): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" - mapping = self.get_children_mapping() + mapping = self.get_children_mapping(exclude=exclude) return self.create_model_class(class_name, mapping) - def to_dict(self, format_func=None, mode="auto") -> Dict: + def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: """将当前节点与子节点都按照node: format的格式组织成字典""" # 如果没有提供格式化函数,使用默认的格式化方式 @@ -180,7 +181,10 @@ class ActionNode: return node_dict # 遍历子节点并递归调用 to_dict 方法 + exclude = exclude or [] for _, child_node in self.children.items(): + if child_node.key in exclude: + continue node_dict.update(child_node.to_dict(format_func)) return node_dict @@ -201,25 +205,25 @@ class ActionNode: else: # markdown return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str: - nodes = self.to_dict(format_func=format_func, mode=mode) + def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: + nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) text = self.compile_to(nodes, schema, kv_sep) return self.tagging(text, schema, tag) - def compile_instruction(self, schema="markdown", mode="children", tag="") -> str: + def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(schema, mode, tag, format_func, kv_sep=": ") + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) - def compile_example(self, schema="json", mode="children", tag="") -> str: + def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(schema, mode, tag, format_func, kv_sep="\n") + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) - def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: """ mode: all/root/children mode="children": 编译所有子节点为一个统一模板,包括instruction与example @@ -235,8 +239,8 @@ class ActionNode: # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", # compile example暂时不支持markdown - instruction = self.compile_instruction(schema="markdown", mode=mode) - example = self.compile_example(schema=schema, tag=TAG, mode=mode) + instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) + example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) # nodes = ", ".join(self.to_dict(mode=mode).keys()) constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] constraint = "\n".join(constraints) @@ -291,11 +295,11 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=CONFIG.timeout): - prompt = self.compile(context=self.context, schema=schema, mode=mode) + async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): + prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": - mapping = self.get_mapping(mode) + mapping = self.get_mapping(mode, exclude=exclude) class_name = f"{self.key}_AN" content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) self.content = content @@ -306,7 +310,7 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): """Fill the node(s) with mode. :param context: Everything we should know when filling node. @@ -323,6 +327,7 @@ class ActionNode: - simple: run only once - complex: run each node :param timeout: Timeout for llm invocation. + :param exclude: The keys of ActionNode to exclude. :return: self """ self.set_llm(llm) @@ -331,12 +336,14 @@ class ActionNode: schema = self.schema if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout) + return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout) + if exclude and i.key in exclude: + continue + child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 39702d3fd..97d3828bf 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -32,8 +32,7 @@ class PrepareDocuments(Action): def _init_repo(self): """Initialize the Git environment.""" - path = CONFIG.project_path - if not path: + if not CONFIG.project_path: name = CONFIG.project_name or FileRepository.new_filename() path = Path(CONFIG.workspace_path) / name else: diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index a6cc7cc22..5ff7af9ae 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -129,7 +129,8 @@ class CollectLinks(Action): if len(remove) == 0: break - prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp) + model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum()) + prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp) logger.debug(prompt) queries = await self._aask(prompt, [system_text]) try: diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 289354a11..de647f167 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -23,10 +23,10 @@ from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.fix_bug import FixBug from metagpt.actions.write_prd_an import ( + PROJECT_NAME, WP_IS_RELATIVE_NODE, WP_ISSUE_TYPE_NODE, WRITE_PRD_NODE, - WRITE_PRD_NODE_NO_NAME, ) from metagpt.config import CONFIG from metagpt.const import ( @@ -124,8 +124,8 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - write_prd_node = WRITE_PRD_NODE if not project_name else WRITE_PRD_NODE_NO_NAME - node = await write_prd_node.fill(context=context, llm=self.llm) # schema=schema + exclude = [PROJECT_NAME.key] if project_name else [] + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema await self._rename_workspace(node) return node diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index e33da2451..948d7d62f 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -141,6 +141,7 @@ NODES = [ LANGUAGE, PROGRAMMING_LANGUAGE, ORIGINAL_REQUIREMENTS, + PROJECT_NAME, PRODUCT_GOALS, USER_STORIES, COMPETITIVE_ANALYSIS, @@ -151,8 +152,7 @@ NODES = [ ANYTHING_UNCLEAR, ] -WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES + [PROJECT_NAME]) -WRITE_PRD_NODE_NO_NAME = ActionNode.from_children("WritePRD", NODES) +WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES) WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON]) WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON]) diff --git a/metagpt/config.py b/metagpt/config.py index 1ce12216d..82f17706f 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -110,11 +110,7 @@ class Config(metaclass=Singleton): if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): warnings.warn("Use Gemini requires Python >= 3.10") - model_mappings = { - LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL, - LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME, - } - model_name = model_mappings.get(provider) + model_name = self.get_model_name(provider=provider) if model_name: logger.info(f"{provider} Model: {model_name}") if provider: @@ -122,6 +118,14 @@ class Config(metaclass=Singleton): return provider raise NotConfiguredException("You should config a LLM configuration first") + def get_model_name(self, provider=None) -> str: + provider = provider or self.get_default_llm_provider_enum() + model_mappings = { + LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL, + LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME, + } + return model_mappings.get(provider, "") + @staticmethod def _is_valid_llm_key(k: str) -> bool: return bool(k and k != "YOUR_API_KEY") diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 750184198..b8a436cb8 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -43,7 +43,8 @@ class SerpAPIWrapper(BaseModel): async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" - return self._process_response(await self.results(query, max_results), as_string=as_string) + result = await self.results(query, max_results) + return self._process_response(result, as_string=as_string) async def results(self, query: str, max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 41b2acbd5..f2f2bf51c 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,9 +4,8 @@ import json from pathlib import Path -import aiofiles - from metagpt.provider.openai_api import OpenAILLM as GPTAPI +from metagpt.utils.common import awrite ICL_SAMPLE = """Interface definition: ```text @@ -255,20 +254,14 @@ class UTGenerator: return doc - async def _store(self, data, base, folder, fname): - """Store data in a file.""" - file_path = self.get_file_path(Path(base) / folder, fname) - async with aiofiles.open(file_path, mode="w", encoding="utf-8") as file: - await file.write(data) - async def ask_gpt_and_save(self, question: str, tag: str, fname: str): """Generate questions and store both questions and answers""" messages = [self.icl_sample, question] result = await self.gpt_msgs_to_code(messages=messages) - await self._store(question, self.questions_path, tag, f"{fname}.txt") + await awrite(Path(self.questions_path) / tag / f"{fname}.txt", question) data = result.get("code", "") if result else "" - await self._store(data, self.ut_py_path, tag, f"{fname}.py") + await awrite(Path(self.ut_py_path) / tag / f"{fname}.py", data) async def _generate_ut(self, tag, paths): """Process the structure under a data path @@ -291,15 +284,3 @@ class UTGenerator: result = await GPTAPI().aask_code(messages=messages) return result - - def get_file_path(self, base: Path, fname: str): - """Save different file paths - - Args: - base (str): Path - fname (str): File name - """ - path = Path(base) - path.mkdir(parents=True, exist_ok=True) - file_path = path / fname - return str(file_path) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index ced17bb7f..f03de1da1 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -537,6 +537,14 @@ async def aread(file_path: str) -> str: return content +async def awrite(filename: str | Path, data: str): + """Write file asynchronously.""" + pathname = Path(filename) + pathname.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer: + await writer.write(data) + + async def read_file_block(filename: str | Path, lineno: int, end_lineno: int): if not Path(filename).exists(): return "" diff --git a/tests/metagpt/actions/test_azure_tts.py b/tests/metagpt/actions/test_azure_tts.py deleted file mode 100644 index 9995e9691..000000000 --- a/tests/metagpt/actions/test_azure_tts.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/7/1 22:50 -@Author : alexanderwu -@File : test_azure_tts.py -""" -from metagpt.tools.azure_tts import AzureTTS - - -def test_azure_tts(): - azure_tts = AzureTTS() - azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav") - - # 运行需要先配置 SUBSCRIPTION_KEY - # TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有 diff --git a/tests/metagpt/actions/test_research.py b/tests/metagpt/actions/test_research.py new file mode 100644 index 000000000..91f83add9 --- /dev/null +++ b/tests/metagpt/actions/test_research.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/28 +@Author : mashenquan +@File : test_research.py +""" + +import pytest + +from metagpt.actions import CollectLinks + + +@pytest.mark.asyncio +async def test_action(): + action = CollectLinks() + result = await action.run(topic="baidu") + assert result + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_talk_action.py b/tests/metagpt/actions/test_talk_action.py new file mode 100644 index 000000000..953fdf44a --- /dev/null +++ b/tests/metagpt/actions/test_talk_action.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/28 +@Author : mashenquan +@File : test_talk_action.py +""" + +import pytest + +from metagpt.actions.talk_action import TalkAction +from metagpt.config import CONFIG +from metagpt.schema import Message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("agent_description", "language", "context", "knowledge", "history_summary"), + [ + ( + "mathematician", + "English", + "How old is Susie?", + "Susie is a girl born in 2011/11/14. Today is 2023/12/3", + "balabala... (useless words)", + ), + ( + "mathematician", + "Chinese", + "Does Susie have an apple?", + "Susie is a girl born in 2011/11/14. Today is 2023/12/3", + "Susie had an apple, and she ate it right now", + ), + ], +) +async def test_prompt(agent_description, language, context, knowledge, history_summary): + # Prerequisites + CONFIG.agent_description = agent_description + CONFIG.language = language + + action = TalkAction(context=context, knowledge=knowledge, history_summary=history_summary) + assert "{" not in action.prompt + assert "{" not in action.prompt_gpt4 + + rsp = await action.run() + assert rsp + assert isinstance(rsp, Message) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py index e3d20a759..f9ad20ee7 100644 --- a/tests/metagpt/learn/test_text_to_embedding.py +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -12,7 +12,6 @@ import asyncio from pydantic import BaseModel from metagpt.learn.text_to_embedding import text_to_embedding -from metagpt.tools.openai_text_to_embedding import ResultEmbedding async def mock_text_to_embedding(): @@ -23,8 +22,7 @@ async def mock_text_to_embedding(): for i in inputs: seed = Input(**i) - data = await text_to_embedding(seed.input) - v = ResultEmbedding(**data) + v = await text_to_embedding(seed.input) assert len(v.data) > 0 diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 5e49023a0..53708527f 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -9,6 +9,7 @@ import importlib import os import platform +import uuid from pathlib import Path from typing import Any, Set @@ -25,6 +26,8 @@ from metagpt.utils.common import ( OutputParser, any_to_str, any_to_str_set, + aread, + awrite, check_cmd_exists, concat_namespace, import_class_inst, @@ -170,6 +173,14 @@ class TestGetProjectRoot: async def test_read_file_block(self): assert await read_file_block(filename=__file__, lineno=6, end_lineno=6) == "@File : test_common.py\n" + @pytest.mark.asyncio + async def test_read_write(self): + pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp" + await awrite(pathname, "ABC") + data = await aread(pathname) + assert data == "ABC" + pathname.unlink(missing_ok=True) + if __name__ == "__main__": pytest.main([__file__, "-s"]) From c61a3d2a99769efa74e9d7b94280a406cf44c909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 28 Dec 2023 15:42:36 +0800 Subject: [PATCH 3/3] feat: +unit test --- metagpt/memory/brain_memory.py | 24 ++--- metagpt/utils/redis.py | 4 +- tests/data/demo_project/code_summaries.json | 1 + tests/data/demo_project/system_design.json | 1 + tests/data/demo_project/tasks.json | 1 + tests/data/demo_project/test_game.py.json | 1 + tests/metagpt/actions/test_skill_action.py | 24 ++++- tests/metagpt/actions/test_write_code.py | 56 +++++++++++ tests/metagpt/learn/test_text_to_speech.py | 47 ++++----- tests/metagpt/memory/test_brain_memory.py | 104 +++++++++++--------- 10 files changed, 177 insertions(+), 86 deletions(-) create mode 100644 tests/data/demo_project/code_summaries.json create mode 100644 tests/data/demo_project/system_design.json create mode 100644 tests/data/demo_project/tasks.json create mode 100644 tests/data/demo_project/test_game.py.json diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index c882859d8..36d5d5cdc 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -55,9 +55,9 @@ class BrainMemory(BaseModel): return "\n".join(texts) @staticmethod - async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory": - redis = Redis(conf=redis_conf) - if not redis.is_valid() or not redis_key: + async def loads(redis_key: str) -> "BrainMemory": + redis = Redis() + if not redis.is_valid or not redis_key: return BrainMemory() v = await redis.get(key=redis_key) logger.debug(f"REDIS GET {redis_key} {v}") @@ -67,11 +67,11 @@ class BrainMemory(BaseModel): return bm return BrainMemory() - async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None): + async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60): if not self.is_dirty: return - redis = Redis(conf=redis_conf) - if not redis.is_valid() or not redis_key: + redis = Redis() + if not redis.is_valid or not redis_key: return False v = self.json(ensure_ascii=False) if self.cacheable: @@ -86,26 +86,26 @@ class BrainMemory(BaseModel): async def set_history_summary(self, history_summary, redis_key, redis_conf): if self.historical_summary == history_summary: if self.is_dirty: - await self.dumps(redis_key=redis_key, redis_conf=redis_conf) + await self.dumps(redis_key=redis_key) self.is_dirty = False return self.historical_summary = history_summary self.history = [] - await self.dumps(redis_key=redis_key, redis_conf=redis_conf) + await self.dumps(redis_key=redis_key) self.is_dirty = False def add_history(self, msg: Message): if msg.id: if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1): return - self.history.append(msg.dict()) + self.history.append(msg) self.last_history_id = str(msg.id) self.is_dirty = True def exists(self, text) -> bool: for m in reversed(self.history): - if m.get("content") == text: + if m.content == text: return True return False @@ -163,7 +163,7 @@ class BrainMemory(BaseModel): msgs.reverse() self.history = msgs self.is_dirty = True - await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF) + await self.dumps(redis_key=CONFIG.REDIS_KEY) self.is_dirty = False return BrainMemory.to_metagpt_history_format(self.history) @@ -217,7 +217,7 @@ class BrainMemory(BaseModel): return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) @staticmethod - async def _metagpt_rewrite(sentence: str): + async def _metagpt_rewrite(sentence: str, **kwargs): return sentence @staticmethod diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index 2246e7d11..1ad39be59 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -63,5 +63,5 @@ class Redis: self._client = None @property - def is_valid(self): - return bool(self._client) + def is_valid(self) -> bool: + return self._client is not None diff --git a/tests/data/demo_project/code_summaries.json b/tests/data/demo_project/code_summaries.json new file mode 100644 index 000000000..20bba0dbf --- /dev/null +++ b/tests/data/demo_project/code_summaries.json @@ -0,0 +1 @@ +{"design_filename": "docs/system_design/20231221155954.json", "task_filename": "docs/tasks/20231221155954.json", "codes_filenames": ["game.py", "main.py"], "reason": "```json\n{\n \"game.py\": \"Add handling for no empty cells in add_new_tile function, Update score in move function\",\n \"main.py\": \"Handle game over condition in the game loop\"\n}\n```"} \ No newline at end of file diff --git a/tests/data/demo_project/system_design.json b/tests/data/demo_project/system_design.json new file mode 100644 index 000000000..43c1ac764 --- /dev/null +++ b/tests/data/demo_project/system_design.json @@ -0,0 +1 @@ +{"Implementation approach": "We will use the Pygame library to create the game interface and handle user input. The game logic will be implemented using Python classes and data structures.", "File list": ["main.py", "game.py"], "Data structures and interfaces": "classDiagram\n class Game {\n -grid: List[List[int]]\n -score: int\n -game_over: bool\n +__init__()\n +reset_game()\n +move(direction: str)\n +is_game_over() bool\n +get_empty_cells() List[Tuple[int, int]]\n +add_new_tile()\n +get_score() int\n }\n class UI {\n -game: Game\n +__init__(game: Game)\n +draw_grid()\n +draw_score()\n +draw_game_over()\n +handle_input()\n }\n Game --> UI", "Program call flow": "sequenceDiagram\n participant M as Main\n participant G as Game\n participant U as UI\n M->>G: reset_game()\n M->>U: draw_grid()\n M->>U: draw_score()\n M->>U: handle_input()\n U->>G: move(direction)\n G->>G: add_new_tile()\n G->>U: draw_grid()\n G->>U: draw_score()\n G->>U: draw_game_over()\n G->>G: is_game_over()\n G->>G: get_empty_cells()\n G->>G: get_score()", "Anything UNCLEAR": "..."} \ No newline at end of file diff --git a/tests/data/demo_project/tasks.json b/tests/data/demo_project/tasks.json new file mode 100644 index 000000000..9e38f4664 --- /dev/null +++ b/tests/data/demo_project/tasks.json @@ -0,0 +1 @@ +{"Required Python packages": ["pygame==2.0.1"], "Required Other language third-party packages": ["No third-party dependencies required"], "Logic Analysis": [["game.py", "Contains Game class and related functions for game logic"], ["main.py", "Contains main function, initializes the game and UI"]], "Task list": ["game.py", "main.py"], "Full API spec": "", "Shared Knowledge": "The game logic will be implemented using Python classes and data structures. The Pygame library will be used to create the game interface and handle user input.", "Anything UNCLEAR": "..."} \ No newline at end of file diff --git a/tests/data/demo_project/test_game.py.json b/tests/data/demo_project/test_game.py.json new file mode 100644 index 000000000..143ee3c26 --- /dev/null +++ b/tests/data/demo_project/test_game.py.json @@ -0,0 +1 @@ +{"summary": "---\n## instruction:\nThe errors are caused by both the development code and the test code. The development code needs to be fixed to ensure that the `reset_game` method resets the grid properly. The test code also needs to be fixed to ensure that the `add_new_tile` test does not raise an index out of range error.\n\n## File To Rewrite:\ngame.py\n\n## Status:\nFAIL\n\n## Send To:\nEngineer\n---", "stdout": "", "stderr": "E.......F\n======================================================================\nERROR: test_add_new_tile (__main__.TestGame)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/Users/xx/tests/test_game.py\", line 104, in test_add_new_tile\n self.assertIn(self.game.grid[empty_cells[0][0]][empty_cells[0][1]], [2, 4])\nIndexError: list index out of range\n\n======================================================================\nFAIL: test_reset_game (__main__.TestGame)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/Users/xx/tests/test_game.py\", line 13, in test_reset_game\n self.assertEqual(self.game.grid, [[0 for _ in range(4)] for _ in range(4)])\nAssertionError: Lists differ: [[0, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 2], [0, 0, 0, 0]] != [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]\n\nFirst differing element 1:\n[0, 2, 0, 0]\n[0, 0, 0, 0]\n\n- [[0, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 2], [0, 0, 0, 0]]\n? --- ^\n\n+ [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]\n? +++ ^\n\n\n----------------------------------------------------------------------\nRan 9 tests in 0.002s\n\nFAILED (failures=1, errors=1)\n"} \ No newline at end of file diff --git a/tests/metagpt/actions/test_skill_action.py b/tests/metagpt/actions/test_skill_action.py index ab764930c..0e0d5d5aa 100644 --- a/tests/metagpt/actions/test_skill_action.py +++ b/tests/metagpt/actions/test_skill_action.py @@ -58,7 +58,29 @@ class TestSkillAction: action = SkillAction(skill=self.skill, args=parser_action.args) rsp = await action.run() assert rsp - assert "image/png;base64," in rsp.content + assert "image/png;base64," in rsp.content or "http" in rsp.content + + @pytest.mark.parametrize( + ("skill_name", "txt", "want"), + [ + ("skill1", 'skill1(a="1", b="2")', {"a": "1", "b": "2"}), + ("skill1", '(a="1", b="2")', None), + ("skill1", 'skill1(a="1", b="2"', None), + ], + ) + def test_parse_arguments(self, skill_name, txt, want): + args = ArgumentsParingAction.parse_arguments(skill_name, txt) + assert args == want + + @pytest.mark.asyncio + async def test_find_and_call_function_error(self): + with pytest.raises(ValueError): + await SkillAction.find_and_call_function("dummy_call", {"a": 1}) + + @pytest.mark.asyncio + async def test_skill_action_error(self): + action = SkillAction(skill=self.skill, args={}) + await action.run() if __name__ == "__main__": diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 40a3b44ed..e43158f68 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -6,12 +6,24 @@ @File : test_write_code.py @Modifiled By: mashenquan, 2023-12-6. According to RFC 135 """ + +from pathlib import Path + import pytest from metagpt.actions.write_code import WriteCode +from metagpt.config import CONFIG +from metagpt.const import ( + CODE_SUMMARIES_FILE_REPO, + SYSTEM_DESIGN_FILE_REPO, + TASK_FILE_REPO, + TEST_OUTPUTS_FILE_REPO, +) from metagpt.logs import logger from metagpt.provider.openai_api import OpenAILLM as LLM from metagpt.schema import CodingContext, Document +from metagpt.utils.common import aread +from metagpt.utils.file_repository import FileRepository from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE @@ -37,3 +49,47 @@ async def test_write_code_directly(): llm = LLM() rsp = await llm.aask(prompt) logger.info(rsp) + + +@pytest.mark.asyncio +async def test_write_code_deps(): + # Prerequisites + CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1" + demo_path = Path(__file__).parent / "../../data/demo_project" + await FileRepository.save_file( + filename="test_game.py.json", + content=await aread(str(demo_path / "test_game.py.json")), + relative_path=TEST_OUTPUTS_FILE_REPO, + ) + await FileRepository.save_file( + filename="20231221155954.json", + content=await aread(str(demo_path / "code_summaries.json")), + relative_path=CODE_SUMMARIES_FILE_REPO, + ) + await FileRepository.save_file( + filename="20231221155954.json", + content=await aread(str(demo_path / "system_design.json")), + relative_path=SYSTEM_DESIGN_FILE_REPO, + ) + await FileRepository.save_file( + filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO + ) + await FileRepository.save_file( + filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace + ) + context = CodingContext( + filename="game.py", + design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO), + task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO), + code_doc=Document(filename="game.py", content="", root_path="snake1"), + ) + coding_doc = Document(root_path="snake1", filename="game.py", content=context.json()) + + action = WriteCode(context=coding_doc) + rsp = await action.run() + assert rsp + assert rsp.code_doc.content + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/learn/test_text_to_speech.py b/tests/metagpt/learn/test_text_to_speech.py index 42b6839fa..2e2f223dc 100644 --- a/tests/metagpt/learn/test_text_to_speech.py +++ b/tests/metagpt/learn/test_text_to_speech.py @@ -6,40 +6,33 @@ @File : test_text_to_speech.py @Desc : Unit tests. """ -import asyncio -import base64 -from pydantic import BaseModel +import pytest +from metagpt.config import CONFIG from metagpt.learn.text_to_speech import text_to_speech -async def mock_text_to_speech(): - class Input(BaseModel): - input: str +@pytest.mark.asyncio +async def test_text_to_speech(): + # Prerequisites + assert CONFIG.IFLYTEK_APP_ID + assert CONFIG.IFLYTEK_API_KEY + assert CONFIG.IFLYTEK_API_SECRET + assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY" + assert CONFIG.AZURE_TTS_REGION - inputs = [{"input": "Panda emoji"}] + # test azure + data = await text_to_speech("panda emoji") + assert "base64" in data or "http" in data - for i in inputs: - seed = Input(**i) - base64_data = await text_to_speech(seed.input) - assert base64_data != "" - print(f"{seed.input} -> {base64_data}") - flags = ";base64," - assert flags in base64_data - ix = base64_data.find(flags) + len(flags) - declaration = base64_data[0:ix] - assert declaration - data = base64_data[ix:] - assert data - assert base64.b64decode(data, validate=True) - - -def test_suite(): - loop = asyncio.get_event_loop() - task = loop.create_task(mock_text_to_speech()) - loop.run_until_complete(task) + # test iflytek + key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY + CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = "" + data = await text_to_speech("panda emoji") + assert "base64" in data or "http" in data + CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key if __name__ == "__main__": - test_suite() + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/memory/test_brain_memory.py b/tests/metagpt/memory/test_brain_memory.py index 32e58c70e..9244f9571 100644 --- a/tests/metagpt/memory/test_brain_memory.py +++ b/tests/metagpt/memory/test_brain_memory.py @@ -5,47 +5,63 @@ @Author : mashenquan @File : test_brain_memory.py """ -# import json -# from typing import List -# -# import pydantic -# -# from metagpt.memory.brain_memory import BrainMemory -# from metagpt.schema import Message -# -# -# def test_json(): -# class Input(pydantic.BaseModel): -# history: List[str] -# solution: List[str] -# knowledge: List[str] -# stack: List[str] -# -# inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}] -# -# for i in inputs: -# v = Input(**i) -# bm = BrainMemory() -# for h in v.history: -# msg = Message(content=h) -# bm.history.append(msg.dict()) -# for h in v.solution: -# msg = Message(content=h) -# bm.solution.append(msg.dict()) -# for h in v.knowledge: -# msg = Message(content=h) -# bm.knowledge.append(msg.dict()) -# for h in v.stack: -# msg = Message(content=h) -# bm.stack.append(msg.dict()) -# s = bm.json() -# m = json.loads(s) -# bm = BrainMemory(**m) -# assert bm -# for v in bm.history: -# msg = Message(**v) -# assert msg -# -# -# if __name__ == "__main__": -# test_json() +import pytest + +from metagpt.config import LLMProviderEnum +from metagpt.llm import LLM +from metagpt.memory.brain_memory import BrainMemory +from metagpt.schema import Message + + +@pytest.mark.asyncio +async def test_memory(): + memory = BrainMemory() + memory.add_talk(Message(content="talk")) + assert memory.history[0].role == "user" + memory.add_answer(Message(content="answer")) + assert memory.history[1].role == "assistant" + redis_key = BrainMemory.to_redis_key("none", "user_id", "chat_id") + await memory.dumps(redis_key=redis_key) + assert memory.exists("talk") + assert 1 == memory.to_int("1", 0) + memory.last_talk = "AAA" + assert memory.pop_last_talk() == "AAA" + assert memory.last_talk is None + assert memory.is_history_available + assert memory.history_text + + memory = await BrainMemory.loads(redis_key=redis_key) + assert memory + + +@pytest.mark.parametrize( + ("input", "tag", "val"), + [("[TALK]:Hello", "TALK", "Hello"), ("Hello", None, "Hello"), ("[TALK]Hello", None, "[TALK]Hello")], +) +def test_extract_info(input, tag, val): + t, v = BrainMemory.extract_info(input) + assert tag == t + assert val == v + + +@pytest.mark.asyncio +@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)]) +async def test_memory_llm(llm): + memory = BrainMemory() + for i in range(500): + memory.add_talk(Message(content="Lily is a girl.\n")) + + res = await memory.is_related("apple", "moon", llm) + assert not res + + res = await memory.rewrite(sentence="apple Lily eating", context="", llm=llm) + assert "Lily" in res + + res = await memory.get_title(llm=llm) + assert res + assert "Lily" in res + assert memory.history or memory.historical_summary + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])