refine code

This commit is contained in:
geekan 2024-01-09 21:31:38 +08:00
parent b0efa4b6a5
commit c9e05a2186
2 changed files with 27 additions and 21 deletions

View file

@ -10,10 +10,11 @@ from __future__ import annotations
from typing import Optional, Union
from pydantic import ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
import metagpt
from metagpt.actions.action_node import ActionNode
from metagpt.config2 import ConfigMixin
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
@ -27,7 +28,7 @@ from metagpt.schema import (
from metagpt.utils.file_repository import FileRepository
class Action(SerializationMixin):
class Action(SerializationMixin, ConfigMixin, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""

View file

@ -146,6 +146,23 @@ class Role(SerializationMixin, ConfigMixin, BaseModel):
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
def __init__(self, **data: Any):
self.pydantic_rebuild_model()
super().__init__(**data)
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
if self.latest_observed_msg:
self.recovered = True
@staticmethod
def pydantic_rebuild_model():
from metagpt.environment import Environment
Environment
Role.model_rebuild()
@property
def todo(self) -> Action:
return self.rc.todo
@ -157,6 +174,9 @@ class Role(SerializationMixin, ConfigMixin, BaseModel):
@property
def config(self):
"""Role config: role config > context config"""
if self._config:
return self._config
return self.context.config
@property
@ -177,19 +197,19 @@ class Role(SerializationMixin, ConfigMixin, BaseModel):
@property
def prompt_schema(self):
return self.context.config.prompt_schema
return self.config.prompt_schema
@property
def project_name(self):
return self.context.config.project_name
return self.config.project_name
@project_name.setter
def project_name(self, value):
self.context.config.project_name = value
self.config.project_name = value
@property
def project_path(self):
return self.context.config.project_path
return self.config.project_path
@model_validator(mode="after")
def check_addresses(self):
@ -197,21 +217,6 @@ class Role(SerializationMixin, ConfigMixin, BaseModel):
self.addresses = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
return self
def __init__(self, **data: Any):
# --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined #
from metagpt.environment import Environment
Environment
# ------
Role.model_rebuild()
super().__init__(**data)
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
if self.latest_observed_msg:
self.recovered = True
def _reset(self):
self.states = []
self.actions = []