From 613515836d45c53e44efe46f0b945f95c7bcb67d Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 9 Jan 2024 22:04:49 +0800 Subject: [PATCH] refine code --- metagpt/actions/action.py | 23 +++++------ metagpt/actions/debug_error.py | 4 +- metagpt/actions/design_api.py | 2 +- metagpt/actions/design_api_review.py | 2 +- metagpt/actions/execute_task.py | 2 +- metagpt/actions/invoice_ocr.py | 6 +-- metagpt/actions/prepare_documents.py | 6 +-- metagpt/actions/project_management.py | 2 +- metagpt/actions/rebuild_class_view.py | 6 +-- metagpt/actions/rebuild_sequence_view.py | 2 +- metagpt/actions/research.py | 6 +-- metagpt/actions/run_code.py | 4 +- metagpt/actions/search_and_summarize.py | 23 ++++------- metagpt/actions/summarize_code.py | 4 +- metagpt/actions/talk_action.py | 6 +-- metagpt/actions/write_code.py | 6 +-- metagpt/actions/write_code_review.py | 6 +-- metagpt/actions/write_docstring.py | 2 +- metagpt/actions/write_prd_review.py | 2 +- metagpt/actions/write_teaching_plan.py | 2 +- metagpt/actions/write_test.py | 2 +- metagpt/config.py | 4 +- metagpt/config2.py | 21 ---------- metagpt/context.py | 52 ++++++++++++++++++++++++ metagpt/roles/engineer.py | 16 ++++---- metagpt/roles/role.py | 16 ++------ tests/metagpt/test_config.py | 5 ++- 27 files changed, 123 insertions(+), 109 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index cdedfcd64..cabab784f 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -12,10 +12,8 @@ from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field, model_validator -import metagpt from metagpt.actions.action_node import ActionNode -from metagpt.config2 import ConfigMixin -from metagpt.context import Context +from metagpt.context import ContextMixin from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from metagpt.schema import ( @@ -28,44 +26,43 @@ from metagpt.schema import ( from metagpt.utils.file_repository import FileRepository -class Action(SerializationMixin, ConfigMixin, BaseModel): +class Action(SerializationMixin, ContextMixin, BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) name: str = "" llm: BaseLLM = Field(default_factory=LLM, exclude=True) - context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = "" + i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = "" prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - g_context: Optional[Context] = Field(default=metagpt.context.context, exclude=True) @property def git_repo(self): - return self.g_context.git_repo + return self.context.git_repo @property def file_repo(self): - return FileRepository(self.g_context.git_repo) + return FileRepository(self.context.git_repo) @property def src_workspace(self): - return self.g_context.src_workspace + return self.context.src_workspace @property def prompt_schema(self): - return self.g_context.config.prompt_schema + return self.config.prompt_schema @property def project_name(self): - return self.g_context.config.project_name + return self.config.project_name @project_name.setter def project_name(self, value): - self.g_context.config.project_name = value + self.config.project_name = value @property def project_path(self): - return self.g_context.config.project_path + return self.config.project_path @model_validator(mode="before") @classmethod diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index aa84d1f11..3647640c0 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -47,7 +47,7 @@ Now you should start rewriting the code: class DebugError(Action): - context: RunCodeContext = Field(default_factory=RunCodeContext) + i_context: RunCodeContext = Field(default_factory=RunCodeContext) async def run(self, *args, **kwargs) -> str: output_doc = await self.file_repo.get_file( @@ -63,7 +63,7 @@ class DebugError(Action): logger.info(f"Debug and rewrite {self.context.test_filename}") code_doc = await self.file_repo.get_file( - filename=self.context.code_filename, relative_path=self.g_context.src_workspace + filename=self.context.code_filename, relative_path=self.context.src_workspace ) if not code_doc: return "" diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index b89ec7877..3e978f823 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -37,7 +37,7 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" - context: Optional[str] = None + i_context: Optional[str] = None desc: str = ( "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index fb1b92d85..ccd01a4c3 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -13,7 +13,7 @@ from metagpt.actions.action import Action class DesignReview(Action): name: str = "DesignReview" - context: Optional[str] = None + i_context: Optional[str] = None async def run(self, prd, api_design): prompt = ( diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py index 4ae4ee17b..1cc3bd699 100644 --- a/metagpt/actions/execute_task.py +++ b/metagpt/actions/execute_task.py @@ -13,7 +13,7 @@ from metagpt.schema import Message class ExecuteTask(Action): name: str = "ExecuteTask" - context: list[Message] = [] + i_context: list[Message] = [] async def run(self, *args, **kwargs): pass diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index 36570097a..a3406ff65 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -41,7 +41,7 @@ class InvoiceOCR(Action): """ name: str = "InvoiceOCR" - context: Optional[str] = None + i_context: Optional[str] = None @staticmethod async def _check_file_type(file_path: Path) -> str: @@ -132,7 +132,7 @@ class GenerateTable(Action): """ name: str = "GenerateTable" - context: Optional[str] = None + i_context: Optional[str] = None llm: BaseLLM = Field(default_factory=LLM) language: str = "ch" @@ -177,7 +177,7 @@ class ReplyQuestion(Action): """ name: str = "ReplyQuestion" - context: Optional[str] = None + i_context: Optional[str] = None llm: BaseLLM = Field(default_factory=LLM) language: str = "ch" diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index ae5aaf2b5..8a9e78b2a 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -22,11 +22,11 @@ class PrepareDocuments(Action): """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" name: str = "PrepareDocuments" - context: Optional[str] = None + i_context: Optional[str] = None @property def config(self): - return self.g_context.config + return self.context.config def _init_repo(self): """Initialize the Git environment.""" @@ -39,7 +39,7 @@ class PrepareDocuments(Action): shutil.rmtree(path) self.config.project_path = path self.config.project_name = path.name - self.g_context.git_repo = GitRepository(local_path=path, auto_init=True) + self.context.git_repo = GitRepository(local_path=path, auto_init=True) async def run(self, with_messages, **kwargs): """Create and initialize the workspace folder, initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index b40da824f..bb8141a74 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -36,7 +36,7 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" - context: Optional[str] = None + i_context: Optional[str] = None async def run(self, with_messages): system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index 5128b9fee..876beccec 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -32,13 +32,13 @@ class RebuildClassView(Action): async def run(self, with_messages=None, format=CONFIG.prompt_schema): graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) - repo_parser = RepoParser(base_directory=Path(self.context)) + repo_parser = RepoParser(base_directory=Path(self.i_context)) # use pylint - class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.context)) + class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context)) await GraphRepository.update_graph_db_with_class_views(graph_db, class_views) await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views) # use ast - direction, diff_path = self._diff_path(path_root=Path(self.context).resolve(), package_root=package_root) + direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root) symbols = repo_parser.generate_symbols() for file_info in symbols: # Align to the same root directory in accordance with `class_views`. diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py index 865050c93..bc128d8b0 100644 --- a/metagpt/actions/rebuild_sequence_view.py +++ b/metagpt/actions/rebuild_sequence_view.py @@ -41,7 +41,7 @@ class RebuildSequenceView(Action): async def _rebuild_sequence_view(self, entry, graph_db): filename = entry.subject.split(":", 1)[0] - src_filename = RebuildSequenceView._get_full_filename(root=self.context, pathname=filename) + src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename) content = await aread(filename=src_filename, encoding="utf-8") content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram." data = await self.llm.aask( diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 90b08cb6a..84067ad92 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -81,7 +81,7 @@ class CollectLinks(Action): """Action class to collect links from a search engine.""" name: str = "CollectLinks" - context: Optional[str] = None + i_context: Optional[str] = None desc: str = "Collect links from a search engine." search_engine: SearchEngine = Field(default_factory=SearchEngine) @@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action): """Action class to explore the web and provide summaries of articles and webpages.""" name: str = "WebBrowseAndSummarize" - context: Optional[str] = None + i_context: Optional[str] = None llm: BaseLLM = Field(default_factory=LLM) desc: str = "Explore the web and provide summaries of articles and webpages." browse_func: Union[Callable[[list[str]], None], None] = None @@ -248,7 +248,7 @@ class ConductResearch(Action): """Action class to conduct research and generate a research report.""" name: str = "ConductResearch" - context: Optional[str] = None + i_context: Optional[str] = None llm: BaseLLM = Field(default_factory=LLM) def __init__(self, **kwargs): diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 0d42308c1..8fdda0a0d 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -76,7 +76,7 @@ standard errors: class RunCode(Action): name: str = "RunCode" - context: RunCodeContext = Field(default_factory=RunCodeContext) + i_context: RunCodeContext = Field(default_factory=RunCodeContext) @classmethod async def run_text(cls, code) -> Tuple[str, str]: @@ -93,7 +93,7 @@ class RunCode(Action): additional_python_paths = [str(path) for path in additional_python_paths] # Copy the current environment variables - env = self.g_context.new_environ() + env = self.context.new_environ() # Modify the PYTHONPATH environment variable additional_python_paths = [working_directory] + additional_python_paths diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 39ca23df5..59b35cd58 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -8,10 +8,9 @@ from typing import Any, Optional import pydantic -from pydantic import Field, model_validator +from pydantic import model_validator from metagpt.actions import Action -from metagpt.config import Config from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools import SearchEngineType @@ -106,28 +105,22 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = None search_func: Optional[Any] = None search_engine: SearchEngine = None result: str = "" - @model_validator(mode="before") - @classmethod - def validate_engine_and_run_func(cls, values): - engine = values.get("engine") - search_func = values.get("search_func") - config = Config() - - if engine is None: - engine = config.search_engine + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.engine is None: + self.engine = self.config.search_engine try: - search_engine = SearchEngine(engine=engine, run_func=search_func) + search_engine = SearchEngine(engine=self.engine, run_func=self.search_func) except pydantic.ValidationError: search_engine = None - values["search_engine"] = search_engine - return values + self.search_engine = search_engine + return self async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: if self.search_engine is None: diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 948eceab2..690d5c77b 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -90,7 +90,7 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" - context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): @@ -103,7 +103,7 @@ class SummarizeCode(Action): design_doc = await repo.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO) task_pathname = Path(self.context.task_filename) task_doc = await repo.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO) - src_file_repo = self.git_repo.new_file_repository(relative_path=self.g_context.src_workspace) + src_file_repo = self.git_repo.new_file_repository(relative_path=self.context.src_workspace) code_blocks = [] for filename in self.context.codes_filenames: code_doc = await src_file_repo.get(filename) diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py index eab1740fc..253b829ed 100644 --- a/metagpt/actions/talk_action.py +++ b/metagpt/actions/talk_action.py @@ -15,18 +15,18 @@ from metagpt.schema import Message class TalkAction(Action): - context: str + i_context: str history_summary: str = "" knowledge: str = "" rsp: Optional[Message] = None @property def agent_description(self): - return self.g_context.kwargs.agent_description + return self.context.kwargs.agent_description @property def language(self): - return self.g_context.kwargs.language or config.language + return self.context.kwargs.language or config.language @property def prompt(self): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 2b8f91a1d..779fe52a6 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -85,7 +85,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" - context: Document = Field(default_factory=Document) + i_context: Document = Field(default_factory=Document) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: @@ -116,7 +116,7 @@ class WriteCode(Action): coding_context.task_doc, exclude=self.context.filename, git_repo=self.git_repo, - src_workspace=self.g_context.src_workspace, + src_workspace=self.context.src_workspace, ) prompt = PROMPT_TEMPLATE.format( @@ -132,7 +132,7 @@ class WriteCode(Action): code = await self.write_code(prompt) if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = self.g_context.src_workspace if self.g_context.src_workspace else "" + root_path = self.context.src_workspace if self.context.src_workspace else "" coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 4433a7ab9..6ff9d5aa4 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -119,7 +119,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" - context: CodingContext = Field(default_factory=CodingContext) + i_context: CodingContext = Field(default_factory=CodingContext) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): @@ -136,14 +136,14 @@ class WriteCodeReview(Action): async def run(self, *args, **kwargs) -> CodingContext: iterative_code = self.context.code_doc.content - k = self.g_context.config.code_review_k_times or 1 + k = self.context.config.code_review_k_times or 1 for i in range(k): format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename) task_content = self.context.task_doc.content if self.context.task_doc else "" code_context = await WriteCode.get_codes( self.context.task_doc, exclude=self.context.filename, - git_repo=self.g_context.git_repo, + git_repo=self.context.git_repo, src_workspace=self.src_workspace, ) context = "\n".join( diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 8b8335517..79204e6a4 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -161,7 +161,7 @@ class WriteDocstring(Action): """ desc: str = "Write docstring for code." - context: Optional[str] = None + i_context: Optional[str] = None async def run( self, diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 2babe38db..68fb5d9e8 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -13,7 +13,7 @@ from metagpt.actions.action import Action class WritePRDReview(Action): name: str = "" - context: Optional[str] = None + i_context: Optional[str] = None prd: Optional[str] = None desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index 76923a663..04507fda3 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -15,7 +15,7 @@ from metagpt.logs import logger class WriteTeachingPlanPart(Action): """Write Teaching Plan Part""" - context: Optional[str] = None + i_context: Optional[str] = None topic: str = "" language: str = "Chinese" rsp: Optional[str] = None diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 96486311f..38b1cf03c 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -39,7 +39,7 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): name: str = "WriteTest" - context: Optional[TestingContext] = None + i_context: Optional[TestingContext] = None async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/config.py b/metagpt/config.py index 0c7b54f83..952ccc962 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -133,8 +133,8 @@ class Config(metaclass=Singleton): self.ollama_api_base = self._get("OLLAMA_API_BASE") self.ollama_api_model = self._get("OLLAMA_API_MODEL") - if not self._get("DISABLE_LLM_PROVIDER_CHECK"): - _ = self.get_default_llm_provider_enum() + # if not self._get("DISABLE_LLM_PROVIDER_CHECK"): + # _ = self.get_default_llm_provider_enum() self.openai_base_url = self._get("OPENAI_BASE_URL") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy diff --git a/metagpt/config2.py b/metagpt/config2.py index 393c46200..cb5c22ac2 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -153,25 +153,4 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict: return result -class ConfigMixin(BaseModel): - """Mixin class for configurable objects""" - - # Env/Role/Action will use this config as private config, or use self.context.config as public config - _config: Optional[Config] = None - - def __init__(self, config: Optional[Config] = None, **kwargs): - """Initialize with config""" - super().__init__(**kwargs) - self.set_config(config) - - def set(self, k, v, override=False): - """Set attribute""" - if override or not self.__dict__.get(k): - self.__dict__[k] = v - - def set_config(self, config: Config, override=False): - """Set config""" - self.set("_config", config, override) - - config = Config.default() diff --git a/metagpt/context.py b/metagpt/context.py index 4016e8d7c..74f7b133d 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -100,5 +100,57 @@ class Context(BaseModel): return llm +class ContextMixin(BaseModel): + """Mixin class for context and config""" + + # Env/Role/Action will use this context as private context, or use self.context as public context + _context: Optional[Context] = None + # Env/Role/Action will use this config as private config, or use self.context.config as public config + _config: Optional[Config] = None + + def __init__(self, context: Optional[Context] = None, config: Optional[Config] = None, **kwargs): + """Initialize with config""" + super().__init__(**kwargs) + self.set_context(context) + self.set_config(config) + + def set(self, k, v, override=False): + """Set attribute""" + if override or not self.__dict__.get(k): + self.__dict__[k] = v + + def set_context(self, context: Context, override=True): + """Set context""" + self.set("_context", context, override) + + def set_config(self, config: Config, override=False): + """Set config""" + self.set("_config", config, override) + + @property + def config(self): + """Role config: role config > context config""" + if self._config: + return self._config + return self.context.config + + @config.setter + def config(self, config: Config): + """Set config""" + self.set_config(config) + + @property + def context(self): + """Role context: role context > context""" + if self._context: + return self._context + return context + + @context.setter + def context(self, context: Context): + """Set context""" + self.set_context(context) + + # Global context, not in Env context = Context() diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index ad0c1ac92..dc9f31686 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -159,9 +159,9 @@ class Engineer(Role): src_relative_path = self.src_workspace.relative_to(self.git_repo.workdir) for todo in self.summarize_todos: summary = await todo.run() - summary_filename = Path(todo.context.design_filename).with_suffix(".md").name - dependencies = {todo.context.design_filename, todo.context.task_filename} - for filename in todo.context.codes_filenames: + summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name + dependencies = {todo.i_context.design_filename, todo.i_context.task_filename} + for filename in todo.i_context.codes_filenames: rpath = src_relative_path / filename dependencies.add(str(rpath)) await code_summaries_pdf_file_repo.save( @@ -169,15 +169,15 @@ class Engineer(Role): ) is_pass, reason = await self._is_pass(summary) if not is_pass: - todo.context.reason = reason - tasks.append(todo.context.dict()) + todo.i_context.reason = reason + tasks.append(todo.i_context.dict()) await code_summaries_file_repo.save( - filename=Path(todo.context.design_filename).name, - content=todo.context.model_dump_json(), + filename=Path(todo.i_context.design_filename).name, + content=todo.i_context.model_dump_json(), dependencies=dependencies, ) else: - await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name) + await code_summaries_file_repo.delete(filename=Path(todo.i_context.design_filename).name) logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}") if not tasks or self.config.max_auto_summarize_code == 0: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 959b5d00d..e31eabd23 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -30,8 +30,7 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement -from metagpt.config2 import ConfigMixin -from metagpt.context import Context, context +from metagpt.context import ContextMixin from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory @@ -120,7 +119,7 @@ class RoleContext(BaseModel): return self.memory.get() -class Role(SerializationMixin, ConfigMixin, BaseModel): +class Role(SerializationMixin, ContextMixin, BaseModel): """Role/Agent""" model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) @@ -142,7 +141,7 @@ class Role(SerializationMixin, ConfigMixin, BaseModel): # builtin variables recovered: bool = False # to tag if a recovered role latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted - context: Optional[Context] = Field(default=context, exclude=True) + # context: Optional[Context] = Field(default=context, exclude=True) __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` @@ -172,16 +171,9 @@ class Role(SerializationMixin, ConfigMixin, BaseModel): def set_todo(self, value: Optional[Action]): """Set action to do and update context""" if value: - value.g_context = self.context + value.context = self.context self.rc.todo = value - @property - def config(self): - """Role config: role config > context config""" - if self._config: - return self._config - return self.context.config - @property def git_repo(self): """Git repo""" diff --git a/tests/metagpt/test_config.py b/tests/metagpt/test_config.py index 0a2c0d462..c74b16930 100644 --- a/tests/metagpt/test_config.py +++ b/tests/metagpt/test_config.py @@ -7,8 +7,9 @@ """ from pydantic import BaseModel -from metagpt.config2 import Config, ConfigMixin, config +from metagpt.config2 import Config, config from metagpt.configs.llm_config import LLMType +from metagpt.context import ContextMixin from tests.metagpt.provider.mock_llm_config import mock_llm_config @@ -29,7 +30,7 @@ def test_config_from_dict(): assert cfg.llm["default"].api_key == "mock_api_key" -class ModelX(ConfigMixin, BaseModel): +class ModelX(ContextMixin, BaseModel): a: str = "a" b: str = "b"