remove mode of get_codes function

This commit is contained in:
mannaandpoem 2024-01-24 16:44:10 +08:00
parent e42dc522c2
commit 3450c240c9
3 changed files with 18 additions and 28 deletions

View file

@ -16,7 +16,6 @@
"""
import json
from typing import Literal
from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
@ -114,7 +113,7 @@ class WriteCode(Action):
code_context = coding_context.code_doc.content
elif code_plan_and_change:
code_context = await self.get_codes(
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, mode="incremental"
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
)
else:
code_context = await self.get_codes(
@ -155,39 +154,31 @@ class WriteCode(Action):
return coding_context
@staticmethod
async def get_codes(
task_doc: Document, exclude: str, project_repo: ProjectRepo, mode: Literal["normal", "incremental"] = "normal"
) -> str:
async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, use_inc: bool = False) -> str:
"""
Get code snippets based on different modes.
Get code snippets that meet the requirements in various scenarios.
Attributes:
task_doc (Document): Document object of the task file.
exclude (str): Specifies the filename to be excluded from the code snippets.
project_repo (ProjectRepo): ProjectRepo object of the project.
mode (str): Specifies the mode, either "normal" or "incremental" (default is "normal").
use_inc (bool): Specifies whether is incremental development.
Returns:
str: Code snippets.
Description:
If mode is set to "normal", it returns code snippets for the regular coding phase,
i.e., all the code generated before writing the current file.
If mode is set to "incremental", it returns code snippets for generating the code plan and change,
building upon the existing code in the "normal" mode and adding code for the current file's older versions.
"""
if not task_doc:
return ""
if not task_doc.content:
task_doc = project_repo.docs.task.get(filename=task_doc.filename)
m = json.loads(task_doc.content)
code_filenames = m.get(TASK_LIST.key, []) if mode == "normal" else m.get(REFINED_TASK_LIST.key, [])
code_filenames = m.get(TASK_LIST.key, []) if use_inc else m.get(REFINED_TASK_LIST.key, [])
codes = []
src_file_repo = project_repo.srcs
if mode == "incremental":
if use_inc:
src_files = src_file_repo.all_files
# Get the old workspace that are created by the previous WriteCodePlanAndChange action
old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace)
old_files = old_file_repo.all_files
# Get the union of the files in the src and old workspaces
@ -213,7 +204,7 @@ class WriteCode(Action):
continue
codes.append(f"----- {filename}\n```{doc.content}```")
elif mode == "normal":
else:
for filename in code_filenames:
# Exclude the current file to get the context code snippets for generating the current file
if filename == exclude:
@ -222,4 +213,5 @@ class WriteCode(Action):
if not doc:
continue
codes.append(f"----- {filename}\n```{doc.content}```")
return "\n".join(codes)

View file

@ -137,11 +137,8 @@ class WriteCodeReview(Action):
async def run(self, *args, **kwargs) -> CodingContext:
iterative_code = self.i_context.code_doc.content
# k = self.context.config.code_review_k_times or 1
k = 1
code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
code_plan_and_change = code_plan_and_change_doc.content if code_plan_and_change_doc else ""
mode = "incremental" if code_plan_and_change else "normal"
k = self.context.config.code_review_k_times or 1
for i in range(k):
format_example = FORMAT_EXAMPLE.format(filename=self.i_context.code_doc.filename)
task_content = self.i_context.task_doc.content if self.i_context.task_doc else ""
@ -149,10 +146,10 @@ class WriteCodeReview(Action):
self.i_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.repo.with_src_path(self.context.src_workspace),
mode=mode,
use_inc=self.config.inc,
)
if not code_plan_and_change:
if not self.config.inc:
context = "\n".join(
[
"## System Design\n" + str(self.i_context.design_doc) + "\n",
@ -162,10 +159,11 @@ class WriteCodeReview(Action):
)
else:
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
context = "\n".join(
[
"## User New Requirements\n" + str(requirement_doc) + "\n",
"## Code Plan And Change\n" + code_plan_and_change + "\n",
"## Code Plan And Change\n" + str(code_plan_and_change_doc) + "\n",
"## System Design\n" + str(self.i_context.design_doc) + "\n",
"## Tasks\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",

View file

@ -23,7 +23,7 @@ import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Literal, Set
from typing import Set
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
from metagpt.actions.fix_bug import FixBug
@ -100,7 +100,7 @@ class Engineer(Role):
m = json.loads(task_msg.content)
return m.get(TASK_LIST.key) or m.get(REFINED_TASK_LIST.key)
async def _act_sp_with_cr(self, review=False, mode: Literal["normal", "incremental"] = "normal") -> Set[str]:
async def _act_sp_with_cr(self, review=False) -> Set[str]:
changed_files = set()
for todo in self.code_todos:
"""
@ -118,7 +118,7 @@ class Engineer(Role):
coding_context = await action.run()
dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
if mode == "incremental":
if self.config.inc:
dependencies.add(os.path.join(CODE_PLAN_AND_CHANGE_FILE_REPO, CODE_PLAN_AND_CHANGE_FILENAME))
await self.project_repo.srcs.save(
filename=coding_context.filename,