From c9e05a2186cd8d95e73f0ed41937b38e4f7721d5 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 9 Jan 2024 21:31:38 +0800 Subject: [PATCH] refine code --- metagpt/actions/action.py | 5 +++-- metagpt/roles/role.py | 43 ++++++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 9f045bbaa..cdedfcd64 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -10,10 +10,11 @@ from __future__ import annotations from typing import Optional, Union -from pydantic import ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator import metagpt from metagpt.actions.action_node import ActionNode +from metagpt.config2 import ConfigMixin from metagpt.context import Context from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM @@ -27,7 +28,7 @@ from metagpt.schema import ( from metagpt.utils.file_repository import FileRepository -class Action(SerializationMixin): +class Action(SerializationMixin, ConfigMixin, BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) name: str = "" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 88bab72cb..75dff94f2 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -146,6 +146,23 @@ class Role(SerializationMixin, ConfigMixin, BaseModel): __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` + def __init__(self, **data: Any): + self.pydantic_rebuild_model() + super().__init__(**data) + + self.llm.system_prompt = self._get_prefix() + self._watch(data.get("watch") or [UserRequirement]) + + if self.latest_observed_msg: + self.recovered = True + + @staticmethod + def pydantic_rebuild_model(): + from metagpt.environment import Environment + + Environment + Role.model_rebuild() + @property def todo(self) -> Action: return self.rc.todo @@ -157,6 +174,9 @@ class Role(SerializationMixin, ConfigMixin, BaseModel): @property def config(self): + """Role config: role config > context config""" + if self._config: + return self._config return self.context.config @property @@ -177,19 +197,19 @@ class Role(SerializationMixin, ConfigMixin, BaseModel): @property def prompt_schema(self): - return self.context.config.prompt_schema + return self.config.prompt_schema @property def project_name(self): - return self.context.config.project_name + return self.config.project_name @project_name.setter def project_name(self, value): - self.context.config.project_name = value + self.config.project_name = value @property def project_path(self): - return self.context.config.project_path + return self.config.project_path @model_validator(mode="after") def check_addresses(self): @@ -197,21 +217,6 @@ class Role(SerializationMixin, ConfigMixin, BaseModel): self.addresses = {any_to_str(self), self.name} if self.name else {any_to_str(self)} return self - def __init__(self, **data: Any): - # --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined # - from metagpt.environment import Environment - - Environment - # ------ - Role.model_rebuild() - super().__init__(**data) - - self.llm.system_prompt = self._get_prefix() - self._watch(data.get("watch") or [UserRequirement]) - - if self.latest_observed_msg: - self.recovered = True - def _reset(self): self.states = [] self.actions = []