refine code

This commit is contained in:
geekan 2024-01-09 22:04:49 +08:00 committed by 莘权 马
parent 4bb4dce4b9
commit 613515836d
27 changed files with 123 additions and 109 deletions

View file

@ -12,10 +12,8 @@ from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
import metagpt
from metagpt.actions.action_node import ActionNode
from metagpt.config2 import ConfigMixin
from metagpt.context import Context
from metagpt.context import ContextMixin
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import (
@ -28,44 +26,43 @@ from metagpt.schema import (
from metagpt.utils.file_repository import FileRepository
class Action(SerializationMixin, ConfigMixin, BaseModel):
class Action(SerializationMixin, ContextMixin, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
g_context: Optional[Context] = Field(default=metagpt.context.context, exclude=True)
@property
def git_repo(self):
return self.g_context.git_repo
return self.context.git_repo
@property
def file_repo(self):
return FileRepository(self.g_context.git_repo)
return FileRepository(self.context.git_repo)
@property
def src_workspace(self):
return self.g_context.src_workspace
return self.context.src_workspace
@property
def prompt_schema(self):
return self.g_context.config.prompt_schema
return self.config.prompt_schema
@property
def project_name(self):
return self.g_context.config.project_name
return self.config.project_name
@project_name.setter
def project_name(self, value):
self.g_context.config.project_name = value
self.config.project_name = value
@property
def project_path(self):
return self.g_context.config.project_path
return self.config.project_path
@model_validator(mode="before")
@classmethod

View file

@ -47,7 +47,7 @@ Now you should start rewriting the code:
class DebugError(Action):
context: RunCodeContext = Field(default_factory=RunCodeContext)
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
async def run(self, *args, **kwargs) -> str:
output_doc = await self.file_repo.get_file(
@ -63,7 +63,7 @@ class DebugError(Action):
logger.info(f"Debug and rewrite {self.context.test_filename}")
code_doc = await self.file_repo.get_file(
filename=self.context.code_filename, relative_path=self.g_context.src_workspace
filename=self.context.code_filename, relative_path=self.context.src_workspace
)
if not code_doc:
return ""

View file

@ -37,7 +37,7 @@ NEW_REQ_TEMPLATE = """
class WriteDesign(Action):
name: str = ""
context: Optional[str] = None
i_context: Optional[str] = None
desc: str = (
"Based on the PRD, think about the system design, and design the corresponding APIs, "
"data structures, library tables, processes, and paths. Please provide your design, feedback "

View file

@ -13,7 +13,7 @@ from metagpt.actions.action import Action
class DesignReview(Action):
name: str = "DesignReview"
context: Optional[str] = None
i_context: Optional[str] = None
async def run(self, prd, api_design):
prompt = (

View file

@ -13,7 +13,7 @@ from metagpt.schema import Message
class ExecuteTask(Action):
name: str = "ExecuteTask"
context: list[Message] = []
i_context: list[Message] = []
async def run(self, *args, **kwargs):
pass

View file

@ -41,7 +41,7 @@ class InvoiceOCR(Action):
"""
name: str = "InvoiceOCR"
context: Optional[str] = None
i_context: Optional[str] = None
@staticmethod
async def _check_file_type(file_path: Path) -> str:
@ -132,7 +132,7 @@ class GenerateTable(Action):
"""
name: str = "GenerateTable"
context: Optional[str] = None
i_context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
language: str = "ch"
@ -177,7 +177,7 @@ class ReplyQuestion(Action):
"""
name: str = "ReplyQuestion"
context: Optional[str] = None
i_context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
language: str = "ch"

View file

@ -22,11 +22,11 @@ class PrepareDocuments(Action):
"""PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt."""
name: str = "PrepareDocuments"
context: Optional[str] = None
i_context: Optional[str] = None
@property
def config(self):
return self.g_context.config
return self.context.config
def _init_repo(self):
"""Initialize the Git environment."""
@ -39,7 +39,7 @@ class PrepareDocuments(Action):
shutil.rmtree(path)
self.config.project_path = path
self.config.project_name = path.name
self.g_context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""

View file

@ -36,7 +36,7 @@ NEW_REQ_TEMPLATE = """
class WriteTasks(Action):
name: str = "CreateTasks"
context: Optional[str] = None
i_context: Optional[str] = None
async def run(self, with_messages):
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)

View file

@ -32,13 +32,13 @@ class RebuildClassView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.context))
repo_parser = RepoParser(base_directory=Path(self.i_context))
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.context))
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.context).resolve(), package_root=package_root)
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.

View file

@ -41,7 +41,7 @@ class RebuildSequenceView(Action):
async def _rebuild_sequence_view(self, entry, graph_db):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.context, pathname=filename)
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(

View file

@ -81,7 +81,7 @@ class CollectLinks(Action):
"""Action class to collect links from a search engine."""
name: str = "CollectLinks"
context: Optional[str] = None
i_context: Optional[str] = None
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action):
"""Action class to explore the web and provide summaries of articles and webpages."""
name: str = "WebBrowseAndSummarize"
context: Optional[str] = None
i_context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
@ -248,7 +248,7 @@ class ConductResearch(Action):
"""Action class to conduct research and generate a research report."""
name: str = "ConductResearch"
context: Optional[str] = None
i_context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
def __init__(self, **kwargs):

View file

@ -76,7 +76,7 @@ standard errors:
class RunCode(Action):
name: str = "RunCode"
context: RunCodeContext = Field(default_factory=RunCodeContext)
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
@classmethod
async def run_text(cls, code) -> Tuple[str, str]:
@ -93,7 +93,7 @@ class RunCode(Action):
additional_python_paths = [str(path) for path in additional_python_paths]
# Copy the current environment variables
env = self.g_context.new_environ()
env = self.context.new_environ()
# Modify the PYTHONPATH environment variable
additional_python_paths = [working_directory] + additional_python_paths

View file

@ -8,10 +8,9 @@
from typing import Any, Optional
import pydantic
from pydantic import Field, model_validator
from pydantic import model_validator
from metagpt.actions import Action
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
@ -106,28 +105,22 @@ You are a member of a professional butler team and will provide helpful suggesti
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None
config: None = Field(default_factory=Config)
engine: Optional[SearchEngineType] = None
search_func: Optional[Any] = None
search_engine: SearchEngine = None
result: str = ""
@model_validator(mode="before")
@classmethod
def validate_engine_and_run_func(cls, values):
engine = values.get("engine")
search_func = values.get("search_func")
config = Config()
if engine is None:
engine = config.search_engine
@model_validator(mode="after")
def validate_engine_and_run_func(self):
if self.engine is None:
self.engine = self.config.search_engine
try:
search_engine = SearchEngine(engine=engine, run_func=search_func)
search_engine = SearchEngine(engine=self.engine, run_func=self.search_func)
except pydantic.ValidationError:
search_engine = None
values["search_engine"] = search_engine
return values
self.search_engine = search_engine
return self
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
if self.search_engine is None:

View file

@ -90,7 +90,7 @@ flowchart TB
class SummarizeCode(Action):
name: str = "SummarizeCode"
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
async def summarize_code(self, prompt):
@ -103,7 +103,7 @@ class SummarizeCode(Action):
design_doc = await repo.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
task_pathname = Path(self.context.task_filename)
task_doc = await repo.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
src_file_repo = self.git_repo.new_file_repository(relative_path=self.g_context.src_workspace)
src_file_repo = self.git_repo.new_file_repository(relative_path=self.context.src_workspace)
code_blocks = []
for filename in self.context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -15,18 +15,18 @@ from metagpt.schema import Message
class TalkAction(Action):
context: str
i_context: str
history_summary: str = ""
knowledge: str = ""
rsp: Optional[Message] = None
@property
def agent_description(self):
return self.g_context.kwargs.agent_description
return self.context.kwargs.agent_description
@property
def language(self):
return self.g_context.kwargs.language or config.language
return self.context.kwargs.language or config.language
@property
def prompt(self):

View file

@ -85,7 +85,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
class WriteCode(Action):
name: str = "WriteCode"
context: Document = Field(default_factory=Document)
i_context: Document = Field(default_factory=Document)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code(self, prompt) -> str:
@ -116,7 +116,7 @@ class WriteCode(Action):
coding_context.task_doc,
exclude=self.context.filename,
git_repo=self.git_repo,
src_workspace=self.g_context.src_workspace,
src_workspace=self.context.src_workspace,
)
prompt = PROMPT_TEMPLATE.format(
@ -132,7 +132,7 @@ class WriteCode(Action):
code = await self.write_code(prompt)
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = self.g_context.src_workspace if self.g_context.src_workspace else ""
root_path = self.context.src_workspace if self.context.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
return coding_context

View file

@ -119,7 +119,7 @@ REWRITE_CODE_TEMPLATE = """
class WriteCodeReview(Action):
name: str = "WriteCodeReview"
context: CodingContext = Field(default_factory=CodingContext)
i_context: CodingContext = Field(default_factory=CodingContext)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
@ -136,14 +136,14 @@ class WriteCodeReview(Action):
async def run(self, *args, **kwargs) -> CodingContext:
iterative_code = self.context.code_doc.content
k = self.g_context.config.code_review_k_times or 1
k = self.context.config.code_review_k_times or 1
for i in range(k):
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
task_content = self.context.task_doc.content if self.context.task_doc else ""
code_context = await WriteCode.get_codes(
self.context.task_doc,
exclude=self.context.filename,
git_repo=self.g_context.git_repo,
git_repo=self.context.git_repo,
src_workspace=self.src_workspace,
)
context = "\n".join(

View file

@ -161,7 +161,7 @@ class WriteDocstring(Action):
"""
desc: str = "Write docstring for code."
context: Optional[str] = None
i_context: Optional[str] = None
async def run(
self,

View file

@ -13,7 +13,7 @@ from metagpt.actions.action import Action
class WritePRDReview(Action):
name: str = ""
context: Optional[str] = None
i_context: Optional[str] = None
prd: Optional[str] = None
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"

View file

@ -15,7 +15,7 @@ from metagpt.logs import logger
class WriteTeachingPlanPart(Action):
"""Write Teaching Plan Part"""
context: Optional[str] = None
i_context: Optional[str] = None
topic: str = ""
language: str = "Chinese"
rsp: Optional[str] = None

View file

@ -39,7 +39,7 @@ you should correctly import the necessary classes based on these file locations!
class WriteTest(Action):
name: str = "WriteTest"
context: Optional[TestingContext] = None
i_context: Optional[TestingContext] = None
async def write_code(self, prompt):
code_rsp = await self._aask(prompt)

View file

@ -133,8 +133,8 @@ class Config(metaclass=Singleton):
self.ollama_api_base = self._get("OLLAMA_API_BASE")
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
# if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
# _ = self.get_default_llm_provider_enum()
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy

View file

@ -153,25 +153,4 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict:
return result
class ConfigMixin(BaseModel):
"""Mixin class for configurable objects"""
# Env/Role/Action will use this config as private config, or use self.context.config as public config
_config: Optional[Config] = None
def __init__(self, config: Optional[Config] = None, **kwargs):
"""Initialize with config"""
super().__init__(**kwargs)
self.set_config(config)
def set(self, k, v, override=False):
"""Set attribute"""
if override or not self.__dict__.get(k):
self.__dict__[k] = v
def set_config(self, config: Config, override=False):
"""Set config"""
self.set("_config", config, override)
config = Config.default()

View file

@ -100,5 +100,57 @@ class Context(BaseModel):
return llm
class ContextMixin(BaseModel):
"""Mixin class for context and config"""
# Env/Role/Action will use this context as private context, or use self.context as public context
_context: Optional[Context] = None
# Env/Role/Action will use this config as private config, or use self.context.config as public config
_config: Optional[Config] = None
def __init__(self, context: Optional[Context] = None, config: Optional[Config] = None, **kwargs):
"""Initialize with config"""
super().__init__(**kwargs)
self.set_context(context)
self.set_config(config)
def set(self, k, v, override=False):
"""Set attribute"""
if override or not self.__dict__.get(k):
self.__dict__[k] = v
def set_context(self, context: Context, override=True):
"""Set context"""
self.set("_context", context, override)
def set_config(self, config: Config, override=False):
"""Set config"""
self.set("_config", config, override)
@property
def config(self):
"""Role config: role config > context config"""
if self._config:
return self._config
return self.context.config
@config.setter
def config(self, config: Config):
"""Set config"""
self.set_config(config)
@property
def context(self):
"""Role context: role context > context"""
if self._context:
return self._context
return context
@context.setter
def context(self, context: Context):
"""Set context"""
self.set_context(context)
# Global context, not in Env
context = Context()

View file

@ -159,9 +159,9 @@ class Engineer(Role):
src_relative_path = self.src_workspace.relative_to(self.git_repo.workdir)
for todo in self.summarize_todos:
summary = await todo.run()
summary_filename = Path(todo.context.design_filename).with_suffix(".md").name
dependencies = {todo.context.design_filename, todo.context.task_filename}
for filename in todo.context.codes_filenames:
summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name
dependencies = {todo.i_context.design_filename, todo.i_context.task_filename}
for filename in todo.i_context.codes_filenames:
rpath = src_relative_path / filename
dependencies.add(str(rpath))
await code_summaries_pdf_file_repo.save(
@ -169,15 +169,15 @@ class Engineer(Role):
)
is_pass, reason = await self._is_pass(summary)
if not is_pass:
todo.context.reason = reason
tasks.append(todo.context.dict())
todo.i_context.reason = reason
tasks.append(todo.i_context.dict())
await code_summaries_file_repo.save(
filename=Path(todo.context.design_filename).name,
content=todo.context.model_dump_json(),
filename=Path(todo.i_context.design_filename).name,
content=todo.i_context.model_dump_json(),
dependencies=dependencies,
)
else:
await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name)
await code_summaries_file_repo.delete(filename=Path(todo.i_context.design_filename).name)
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
if not tasks or self.config.max_auto_summarize_code == 0:

View file

@ -30,8 +30,7 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.config2 import ConfigMixin
from metagpt.context import Context, context
from metagpt.context import ContextMixin
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.memory import Memory
@ -120,7 +119,7 @@ class RoleContext(BaseModel):
return self.memory.get()
class Role(SerializationMixin, ConfigMixin, BaseModel):
class Role(SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
@ -142,7 +141,7 @@ class Role(SerializationMixin, ConfigMixin, BaseModel):
# builtin variables
recovered: bool = False # to tag if a recovered role
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
context: Optional[Context] = Field(default=context, exclude=True)
# context: Optional[Context] = Field(default=context, exclude=True)
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
@ -172,16 +171,9 @@ class Role(SerializationMixin, ConfigMixin, BaseModel):
def set_todo(self, value: Optional[Action]):
"""Set action to do and update context"""
if value:
value.g_context = self.context
value.context = self.context
self.rc.todo = value
@property
def config(self):
"""Role config: role config > context config"""
if self._config:
return self._config
return self.context.config
@property
def git_repo(self):
"""Git repo"""

View file

@ -7,8 +7,9 @@
"""
from pydantic import BaseModel
from metagpt.config2 import Config, ConfigMixin, config
from metagpt.config2 import Config, config
from metagpt.configs.llm_config import LLMType
from metagpt.context import ContextMixin
from tests.metagpt.provider.mock_llm_config import mock_llm_config
@ -29,7 +30,7 @@ def test_config_from_dict():
assert cfg.llm["default"].api_key == "mock_api_key"
class ModelX(ConfigMixin, BaseModel):
class ModelX(ContextMixin, BaseModel):
a: str = "a"
b: str = "b"