From 10436172ca0f402f927da7c7475c4035578d37be Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 4 Jan 2024 22:02:47 +0800 Subject: [PATCH] add context and config2 --- metagpt/actions/action.py | 14 +++--- metagpt/actions/action_node.py | 6 +-- metagpt/actions/debug_error.py | 4 +- metagpt/actions/prepare_documents.py | 4 +- metagpt/actions/project_management.py | 2 +- metagpt/actions/run_code.py | 2 +- metagpt/actions/summarize_code.py | 2 +- metagpt/actions/write_code.py | 4 +- metagpt/actions/write_code_review.py | 4 +- metagpt/context.py | 4 +- metagpt/roles/assistant.py | 8 ++-- metagpt/roles/engineer.py | 10 +++-- metagpt/roles/invoice_ocr_assistant.py | 2 +- metagpt/roles/qa_engineer.py | 4 +- metagpt/roles/researcher.py | 2 +- metagpt/roles/role.py | 16 +++++-- metagpt/roles/teacher.py | 2 +- tests/conftest.py | 4 +- .../metagpt/actions/test_prepare_documents.py | 12 ++--- tests/metagpt/test_action.py | 7 --- tests/metagpt/test_document.py | 4 +- tests/metagpt/test_environment.py | 8 ++-- tests/metagpt/test_gpt.py | 45 ------------------- tests/metagpt/test_llm.py | 8 +++- tests/metagpt/test_manager.py | 7 --- 25 files changed, 72 insertions(+), 113 deletions(-) delete mode 100644 tests/metagpt/test_action.py delete mode 100644 tests/metagpt/test_gpt.py delete mode 100644 tests/metagpt/test_manager.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index ec80a96dd..fba396896 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -34,31 +34,31 @@ class Action(SerializationMixin, is_polymorphic_base=True): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - _context: Optional[Context] = Field(default=None, exclude=True) + g_context: Optional[Context] = Field(default=None, exclude=True) @property def git_repo(self): - return self._context.git_repo + return self.g_context.git_repo @property def src_workspace(self): - return self._context.src_workspace + return self.g_context.src_workspace @property def prompt_schema(self): - return self._context.config.prompt_schema + return self.g_context.config.prompt_schema @property def project_name(self): - return self._context.config.project_name + return self.g_context.config.project_name @project_name.setter def project_name(self, value): - self._context.config.project_name = value + self.g_context.config.project_name = value @property def project_path(self): - return self._context.config.project_path + return self.g_context.config.project_path @model_validator(mode="before") @classmethod diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 16a43ea69..633fc9841 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -261,7 +261,7 @@ class ActionNode: output_data_mapping: dict, system_msgs: Optional[list[str]] = None, schema="markdown", # compatible to original format - timeout=None, + timeout=3, ) -> (str, BaseModel): """Use ActionOutput to wrap the output of aask""" content = await self.llm.aask(prompt, system_msgs, timeout=timeout) @@ -293,7 +293,7 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=None, exclude=None): + async def simple_fill(self, schema, mode, timeout=3, exclude=None): prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": @@ -308,7 +308,7 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=None, exclude=[]): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]): """Fill the node(s) with mode. :param context: Everything we should know when filling node. diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 2916005c2..09823979e 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -51,7 +51,7 @@ Now you should start rewriting the code: class DebugError(Action): context: RunCodeContext = Field(default_factory=RunCodeContext) - _context: Optional[Context] = None + g_context: Optional[Context] = None async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( @@ -67,7 +67,7 @@ class DebugError(Action): logger.info(f"Debug and rewrite {self.context.test_filename}") code_doc = await FileRepository.get_file( - filename=self.context.code_filename, relative_path=self._context.src_workspace + filename=self.context.code_filename, relative_path=self.g_context.src_workspace ) if not code_doc: return "" diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 3bd362207..afae03cb5 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -26,7 +26,7 @@ class PrepareDocuments(Action): @property def config(self): - return self._context.config + return self.g_context.config def _init_repo(self): """Initialize the Git environment.""" @@ -39,7 +39,7 @@ class PrepareDocuments(Action): shutil.rmtree(path) self.config.project_path = path self.config.project_name = path.name - self._context.git_repo = GitRepository(local_path=path, auto_init=True) + self.g_context.git_repo = GitRepository(local_path=path, auto_init=True) async def run(self, with_messages, **kwargs): """Create and initialize the workspace folder, initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index f8ccd922a..cc35e72e2 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -41,7 +41,7 @@ class WriteTasks(Action): @property def prompt_schema(self): - return self._context.config.prompt_schema + return self.g_context.config.prompt_schema async def run(self, with_messages, schema=None): system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 74ad36dae..0d42308c1 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -93,7 +93,7 @@ class RunCode(Action): additional_python_paths = [str(path) for path in additional_python_paths] # Copy the current environment variables - env = self._context.new_environ() + env = self.g_context.new_environ() # Modify the PYTHONPATH environment variable additional_python_paths = [working_directory] + additional_python_paths diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 94f3c6541..21c0113fd 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -104,7 +104,7 @@ class SummarizeCode(Action): design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO) task_pathname = Path(self.context.task_filename) task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO) - src_file_repo = self.git_repo.new_file_repository(relative_path=self._context.src_workspace) + src_file_repo = self.git_repo.new_file_repository(relative_path=self.g_context.src_workspace) code_blocks = [] for filename in self.context.codes_filenames: code_doc = await src_file_repo.get(filename) diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 5b09aa2b0..0ba5477c6 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -117,7 +117,7 @@ class WriteCode(Action): coding_context.task_doc, exclude=self.context.filename, git_repo=self.git_repo, - src_workspace=self._context.src_workspace, + src_workspace=self.g_context.src_workspace, ) prompt = PROMPT_TEMPLATE.format( @@ -133,7 +133,7 @@ class WriteCode(Action): code = await self.write_code(prompt) if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = self._context.src_workspace if self._context.src_workspace else "" + root_path = self.g_context.src_workspace if self.g_context.src_workspace else "" coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index e261f0623..4433a7ab9 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -136,14 +136,14 @@ class WriteCodeReview(Action): async def run(self, *args, **kwargs) -> CodingContext: iterative_code = self.context.code_doc.content - k = self._context.config.code_review_k_times or 1 + k = self.g_context.config.code_review_k_times or 1 for i in range(k): format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename) task_content = self.context.task_doc.content if self.context.task_doc else "" code_context = await WriteCode.get_codes( self.context.task_doc, exclude=self.context.filename, - git_repo=self._context.git_repo, + git_repo=self.g_context.git_repo, src_workspace=self.src_workspace, ) context = "\n".join( diff --git a/metagpt/context.py b/metagpt/context.py index 53b673b3e..c212f6735 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -9,8 +9,6 @@ import os from pathlib import Path from typing import Dict, Optional -from pydantic import BaseModel - from metagpt.config2 import Config from metagpt.const import OPTIONS from metagpt.provider.base_llm import BaseLLM @@ -19,7 +17,7 @@ from metagpt.utils.cost_manager import CostManager from metagpt.utils.git_repository import GitRepository -class Context(BaseModel): +class Context: kwargs: Dict = {} config: Config = Config.default() git_repo: Optional[GitRepository] = None diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 227578a63..d96d8a895 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -97,8 +97,10 @@ class Assistant(Role): async def talk_handler(self, text, **kwargs) -> bool: history = self.memory.history_text text = kwargs.get("last_talk") or text - self.rc.todo = TalkAction( - context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs + self.set_todo( + TalkAction( + context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs + ) ) return True @@ -112,7 +114,7 @@ class Assistant(Role): await action.run(**kwargs) if action.args is None: return await self.talk_handler(text=last_talk, **kwargs) - self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description) + self.set_todo(SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)) return True async def refine_memory(self) -> str: diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index e20ea42a7..51c831b91 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -281,7 +281,9 @@ class Engineer(Role): f"{changed_files.docs[task_filename].model_dump_json()}" ) changed_files.docs[task_filename] = coding_doc - self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()] + self.code_todos = [ + WriteCode(context=i, g_context=self.context, llm=self.llm) for i in changed_files.docs.values() + ] # Code directly modified by the user. dependency = await self.git_repo.get_dependency() for filename in changed_src_files: @@ -295,10 +297,10 @@ class Engineer(Role): dependency=dependency, ) changed_files.docs[filename] = coding_doc - self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm)) + self.code_todos.append(WriteCode(context=coding_doc, g_context=self.context, llm=self.llm)) if self.code_todos: - self.rc.todo = self.code_todos[0] + self.set_todo(self.code_todos[0]) async def _new_summarize_actions(self): src_file_repo = self.git_repo.new_file_repository(self.src_workspace) @@ -313,7 +315,7 @@ class Engineer(Role): ctx.codes_filenames = filenames self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm)) if self.summarize_todos: - self.rc.todo = self.summarize_todos[0] + self.set_todo(self.summarize_todos[0]) @property def todo(self) -> str: diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py index f5588974b..8635f4307 100644 --- a/metagpt/roles/invoice_ocr_assistant.py +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -87,7 +87,7 @@ class InvoiceOCRAssistant(Role): else: self._init_actions([GenerateTable]) - self.rc.todo = None + self.set_todo(None) content = INVOICE_OCR_SUCCESS resp = OCRResults(ocr_result=json.dumps(resp)) msg = Message(content=content, instruct_content=resp) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 1a6ca2d9c..9104e3e1d 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -72,7 +72,7 @@ class QaEngineer(Role): ) logger.info(f"Writing {test_doc.filename}..") context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc) - context = await WriteTest(context=context, _context=self.context, llm=self.llm).run() + context = await WriteTest(context=context, g_context=self.context, llm=self.llm).run() await tests_file_repo.save( filename=context.test_doc.filename, content=context.test_doc.content, @@ -137,7 +137,7 @@ class QaEngineer(Role): async def _debug_error(self, msg): run_code_context = RunCodeContext.loads(msg.content) - code = await DebugError(context=run_code_context, llm=self.llm).run() + code = await DebugError(context=run_code_context, g_context=self.context, llm=self.llm).run() await FileRepository.save_file( filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO ) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index 15f6c9a22..5110c6485 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -49,7 +49,7 @@ class Researcher(Role): if self.rc.state + 1 < len(self.states): self._set_state(self.rc.state + 1) else: - self.rc.todo = None + self.set_todo(None) return False async def _act(self) -> Message: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 63316b5de..d17331b56 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -154,6 +154,15 @@ class Role(SerializationMixin, is_polymorphic_base=True): __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` + @property + def todo(self) -> Action: + return self.rc.todo + + def set_todo(self, value: Optional[Action]): + if value: + value.g_context = self.context + self.rc.todo = value + @property def config(self): return self.context.config @@ -326,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True): """Update the current state.""" self.rc.state = state logger.debug(f"actions={self.actions}, state={state}") - self.rc.todo = self.actions[self.rc.state] if state >= 0 else None + self.set_todo(self.actions[self.rc.state] if state >= 0 else None) def set_env(self, env: "Environment"): """Set the environment in which the role works. The role can talk to the environment and can also receive @@ -521,7 +530,7 @@ class Role(SerializationMixin, is_polymorphic_base=True): rsp = await self.react() # Reset the next action to be taken. - self.rc.todo = None + self.set_todo(None) # Send the response message to the Environment object to have it relay the message to the subscribers. self.publish_message(rsp) return rsp @@ -542,8 +551,9 @@ class Role(SerializationMixin, is_polymorphic_base=True): return ActionOutput(content=msg.content, instruct_content=msg.instruct_content) @property - def todo(self) -> str: + def first_action(self) -> str: """AgentStore uses this attribute to display to the user what actions the current role should take.""" + # FIXME: this is a hack, we should not use the first action to represent the todo if self.actions: return any_to_name(self.actions[0]) return "" diff --git a/metagpt/roles/teacher.py b/metagpt/roles/teacher.py index 637fd242a..f9583d49b 100644 --- a/metagpt/roles/teacher.py +++ b/metagpt/roles/teacher.py @@ -59,7 +59,7 @@ class Teacher(Role): self._set_state(self.rc.state + 1) return True - self.rc.todo = None + self.set_todo(None) return False async def _react(self) -> Message: diff --git a/tests/conftest.py b/tests/conftest.py index 1f4a73030..7ed66a61d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,9 +104,9 @@ class Context: @pytest.fixture(scope="package") def llm_api(): logger.info("Setting up the test") - _context = Context() + g_context = Context() - yield _context.llm_api + yield g_context.llm_api logger.info("Tearing down the test") diff --git a/tests/metagpt/actions/test_prepare_documents.py b/tests/metagpt/actions/test_prepare_documents.py index 31c8bcb80..c7fb6af20 100644 --- a/tests/metagpt/actions/test_prepare_documents.py +++ b/tests/metagpt/actions/test_prepare_documents.py @@ -9,8 +9,8 @@ import pytest from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.context import context from metagpt.schema import Message from metagpt.utils.file_repository import FileRepository @@ -19,12 +19,12 @@ from metagpt.utils.file_repository import FileRepository async def test_prepare_documents(): msg = Message(content="New user requirements balabala...") - if CONFIG.git_repo: - CONFIG.git_repo.delete_repository() - CONFIG.git_repo = None + if context.git_repo: + context.git_repo.delete_repository() + context.git_repo = None - await PrepareDocuments().run(with_messages=[msg]) - assert CONFIG.git_repo + await PrepareDocuments(g_context=context).run(with_messages=[msg]) + assert context.git_repo doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO) assert doc assert doc.content == msg.content diff --git a/tests/metagpt/test_action.py b/tests/metagpt/test_action.py deleted file mode 100644 index af5106ab4..000000000 --- a/tests/metagpt/test_action.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:44 -@Author : alexanderwu -@File : test_action.py -""" diff --git a/tests/metagpt/test_document.py b/tests/metagpt/test_document.py index e7b08544b..9c076f4e6 100644 --- a/tests/metagpt/test_document.py +++ b/tests/metagpt/test_document.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : test_document.py """ -from metagpt.config import CONFIG +from metagpt.config2 import config from metagpt.document import Repo from metagpt.logs import logger @@ -28,6 +28,6 @@ def load_existing_repo(path): def test_repo_set_load(): - repo_path = CONFIG.path / "test_repo" + repo_path = config.workspace.path / "test_repo" set_existing_repo(repo_path) load_existing_repo(repo_path) diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 3a899d6ff..d7d8d990a 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -13,7 +13,7 @@ from pathlib import Path import pytest from metagpt.actions import UserRequirement -from metagpt.config import CONFIG +from metagpt.context import context from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, Role @@ -46,9 +46,9 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - if CONFIG.git_repo: - CONFIG.git_repo.delete_repository() - CONFIG.git_repo = None + if context.git_repo: + context.git_repo.delete_repository() + context.git_repo = None product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限") architect = Architect( diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py deleted file mode 100644 index 2b19f173d..000000000 --- a/tests/metagpt/test_gpt.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/4/29 19:47 -@Author : alexanderwu -@File : test_gpt.py -""" -import openai -import pytest - -from metagpt.config import CONFIG -from metagpt.logs import logger - - -@pytest.mark.usefixtures("llm_api") -class TestGPT: - @pytest.mark.asyncio - async def test_llm_api_aask(self, llm_api): - answer = await llm_api.aask("hello chatgpt", stream=False) - logger.info(answer) - assert len(answer) > 0 - - answer = await llm_api.aask("hello chatgpt", stream=True) - logger.info(answer) - assert len(answer) > 0 - - @pytest.mark.asyncio - async def test_llm_api_aask_code(self, llm_api): - try: - answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60) - logger.info(answer) - assert len(answer) > 0 - except openai.BadRequestError: - assert CONFIG.OPENAI_API_TYPE == "azure" - - @pytest.mark.asyncio - async def test_llm_api_costs(self, llm_api): - await llm_api.aask("hello chatgpt", stream=False) - costs = llm_api.get_costs() - logger.info(costs) - assert costs.total_cost > 0 - - -if __name__ == "__main__": - pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index 247f043e2..dc18114b1 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -9,7 +9,7 @@ import pytest -from metagpt.provider.openai_api import OpenAILLM as LLM +from metagpt.llm import LLM @pytest.fixture() @@ -23,6 +23,12 @@ async def test_llm_aask(llm): assert len(rsp) > 0 +@pytest.mark.asyncio +async def test_llm_aask_stream(llm): + rsp = await llm.aask("hello world", stream=True) + assert len(rsp) > 0 + + @pytest.mark.asyncio async def test_llm_acompletion(llm): hello_msg = [{"role": "user", "content": "hello"}] diff --git a/tests/metagpt/test_manager.py b/tests/metagpt/test_manager.py deleted file mode 100644 index 5c2a2c795..000000000 --- a/tests/metagpt/test_manager.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:45 -@Author : alexanderwu -@File : test_manager.py -"""