feat: rewrite Engineer & WriteCode & WriteCodeReview

This commit is contained in:
莘权 马 2023-11-23 17:49:38 +08:00
parent 438fbe28c0
commit 2032a38542
7 changed files with 152 additions and 180 deletions

View file

@ -7,13 +7,15 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `cause_by`
value of the `Message` object.
"""
import json
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions import WriteDesign
from metagpt.actions.action import Action
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser, any_to_str
PROMPT_TEMPLATE = """
@ -46,7 +48,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
class WriteCode(Action):
def __init__(self, name="WriteCode", context: list[Message] = None, llm=None):
def __init__(self, name="WriteCode", context=None, llm=None):
super().__init__(name, context, llm)
def _is_invalid(self, filename):
@ -70,15 +72,19 @@ class WriteCode(Action):
logger.info(f"Saving Code to {code_path}")
@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
async def write_code(self, prompt):
async def write_code(self, prompt) -> str:
code_rsp = await self._aask(prompt)
code = CodeParser.parse_code(block="", text=code_rsp)
return code
async def run(self, context, filename):
prompt = PROMPT_TEMPLATE.format(context=context, filename=filename)
logger.info(f"Writing {filename}..")
async def run(self, *args, **kwargs) -> CodingContext:
m = json.loads(self.context.content)
coding_context = CodingContext(**m)
context = "\n".join(
[coding_context.design_doc.content, coding_context.task_doc.content, coding_context.code_doc.content]
)
prompt = PROMPT_TEMPLATE.format(context=context, filename=self.context.filename)
logger.info(f"Writing {coding_context.filename}..")
code = await self.write_code(prompt)
# code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING)
# self._save(context, filename, code)
return code
coding_context.code_doc.content = code
return coding_context

View file

@ -10,7 +10,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions.action import Action
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser
PROMPT_TEMPLATE = """
@ -63,7 +63,7 @@ FORMAT_EXAMPLE = """
class WriteCodeReview(Action):
def __init__(self, name="WriteCodeReview", context: list[Message] = None, llm=None):
def __init__(self, name="WriteCodeReview", context=None, llm=None):
super().__init__(name, context, llm)
@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
@ -72,11 +72,18 @@ class WriteCodeReview(Action):
code = CodeParser.parse_code(block="", text=code_rsp)
return code
async def run(self, context, code, filename):
format_example = FORMAT_EXAMPLE.format(filename=filename)
prompt = PROMPT_TEMPLATE.format(context=context, code=code, filename=filename, format_example=format_example)
logger.info(f"Code review {filename}..")
async def run(self, *args, **kwargs) -> CodingContext:
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
context = "\n".join(
[self.context.design_doc.content, self.context.task_doc.content, self.context.code_doc.content]
)
prompt = PROMPT_TEMPLATE.format(
context=context,
code=self.context.code_doc.content,
filename=self.context.code_doc.filename,
format_example=format_example,
)
logger.info(f"Code review {self.context.code_doc.filename}..")
code = await self.write_code(prompt)
# code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING)
# self._save(context, filename, code)
return code
self.context.code_doc.content = code
return self.context

View file

@ -94,6 +94,7 @@ class Config(metaclass=Singleton):
self.prompt_format = self._get("PROMPT_FORMAT", "markdown")
self.git_repo = None
self.src_workspace = None
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
"""Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority"""

View file

@ -11,47 +11,20 @@
3. Fix bug: Add logic for handling asynchronous message processing when messages are not ready.
4. Supplemented the external transmission of internal messages.
"""
import asyncio
import shutil
from collections import OrderedDict
from __future__ import annotations
import json
from pathlib import Path
from metagpt.actions import WriteCode, WriteCodeReview, WriteDesign, WriteTasks
from metagpt.const import WORKSPACE_ROOT
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, any_to_str
from metagpt.schema import CodingContext, Document, Documents, Message
from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP
async def gather_ordered_k(coros, k) -> list:
    """Await ``coros`` with at most ``k`` of them in flight; keep input order.

    Args:
        coros: Iterable of coroutines/awaitables to run.
        k: Maximum number of awaitables allowed to run at the same time.
            Values below 1 are clamped to 1 so that a misconfigured limit
            degrades to sequential execution instead of failing.

    Returns:
        A list where element ``i`` is the result of the ``i``-th awaitable,
        regardless of the order in which they finished. Empty input yields
        an empty list.

    Raises:
        Exception: Whatever exception an awaitable raises is propagated.
    """
    pending = list(coros)
    if not pending:
        return []

    # The semaphore is created inside the running loop and caps concurrency.
    limit = asyncio.Semaphore(max(1, k))

    async def _bounded(coro):
        # Each awaitable only starts once a slot is free.
        async with limit:
            return await coro

    # gather() already returns results in argument order, so no index
    # bookkeeping is needed.
    return await asyncio.gather(*(_bounded(coro) for coro in pending))
class Engineer(Role):
"""
Represents an Engineer role responsible for writing and possibly reviewing code.
@ -77,105 +50,19 @@ class Engineer(Role):
) -> None:
"""Initializes the Engineer role with given attributes."""
super().__init__(name, profile, goal, constraints)
self._init_actions([WriteCode])
self.use_code_review = use_code_review
if self.use_code_review:
self._init_actions([WriteCode, WriteCodeReview])
self._watch([WriteTasks])
self.todos = []
self.n_borg = n_borg
@classmethod
def parse_tasks(self, task_msg: Message) -> list[str]:
    """Return the "Task list" entries carried by a task message.

    The structured ``instruct_content`` is used when it is set; otherwise
    the list is parsed out of the raw message text by ``CodeParser``.
    """
    structured = task_msg.instruct_content
    if not structured:
        # No structured payload: fall back to parsing the free-form text.
        return CodeParser.parse_file_list(block="Task list", text=task_msg.content)
    return structured.dict().get("Task list")
@staticmethod
def _parse_tasks(task_msg: Document) -> list[str]:
m = json.loads(task_msg.content)
return m.get("Task list")
@classmethod
def parse_code(self, code_text: str) -> str:
    """Run ``CodeParser.parse_code`` on ``code_text`` with an empty block name."""
    parsed = CodeParser.parse_code(block="", text=code_text)
    return parsed
@classmethod
def parse_workspace(cls, system_design_msg: Message) -> str:
    """Return the "Python package name" announced by a system-design message.

    The structured ``instruct_content`` is used when it is set (surrounding
    whitespace and quote characters are stripped); otherwise the name is
    parsed out of the raw message text by ``CodeParser``.
    """
    structured = system_design_msg.instruct_content
    if not structured:
        return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
    package_name = structured.dict().get("Python package name")
    # Strip whitespace first, then single quotes, then double quotes.
    return package_name.strip().strip("'").strip('"')
def get_workspace(self) -> Path:
    """Return the directory that generated code is written to.

    Returns:
        ``WORKSPACE_ROOT / <package> / <package>`` where ``<package>`` is
        parsed from the most recent ``WriteDesign`` message in memory, or
        ``WORKSPACE_ROOT / "src"`` when no such message is available.
    """
    design_msgs = self._rc.memory.get_by_action(WriteDesign)
    # Check for an empty result *before* indexing; indexing first would raise
    # IndexError and make the "src" fallback unreachable.
    if not design_msgs or not design_msgs[-1]:
        return WORKSPACE_ROOT / "src"
    workspace = self.parse_workspace(design_msgs[-1])
    # Codes are written in workspace/{package_name}/{package_name}
    return WORKSPACE_ROOT / workspace / workspace
def recreate_workspace(self):
    """Delete the workspace directory tree, if any, and create it empty."""
    target_dir = self.get_workspace()
    try:
        shutil.rmtree(target_dir)
    except FileNotFoundError:
        # Nothing to remove yet; the directory is created just below.
        pass
    target_dir.mkdir(parents=True, exist_ok=True)
def write_file(self, filename: str, code: str):
    """Write ``code`` to ``filename`` inside the workspace and return its path.

    Double quotes and newlines are removed from ``filename`` first; missing
    parent directories are created.
    """
    # One pass that deletes both unwanted characters from the name.
    cleaned_name = filename.translate({ord('"'): None, ord("\n"): None})
    target = self.get_workspace() / cleaned_name
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(code)
    return target
async def _act_mp(self) -> Message:
    """Generate code for every pending todo, running up to ``n_borg`` at once.

    Each response is logged, stored in memory and published as a message.
    NOTE: this path does not write the generated code to disk.

    Returns:
        A final "all done." message attributed to the current todo.
    """
    # Iterate over a snapshot: `self.todos` is consumed inside the loop, and
    # deleting from the list being iterated would skip entries and pair the
    # remaining todos with the wrong responses.
    pending = list(self.todos)
    todo_coros = [
        WriteCode().run(
            context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]),
            filename=todo,
        )
        for todo in pending
    ]

    rsps = await gather_ordered_k(todo_coros, self.n_borg)
    for todo, code_rsp in zip(pending, rsps):
        _ = self.parse_code(code_rsp)  # result is intentionally discarded
        logger.info(todo)
        logger.info(code_rsp)
        msg = Message(content=code_rsp, role=self.profile, cause_by=self._rc.todo)
        self._rc.memory.add(msg)
        self.publish_message(msg)
        # Drop the todo that was just handled (always the current head).
        del self.todos[0]

    logger.info(f"Done {self.get_workspace()} generating.")
    msg = Message(content="all done.", role=self.profile, cause_by=self._rc.todo)
    return msg
async def _act_sp(self) -> Message:
    """Generate and save code for each todo, one at a time.

    Every generated file is written to the workspace, recorded in memory and
    published. The returned message lists each "<todo><sep><path>" pair,
    joined by ``MSG_SEP``, and is addressed to "Edward".
    """
    generated = []  # gather all code info, will pass to qa_engineer for tests later
    for todo in self.todos:
        code = await WriteCode().run(context=self._rc.history, filename=todo)
        saved_path = self.write_file(todo, code)
        code_message = Message(content=code, role=self.profile, cause_by=self._rc.todo)
        self._rc.memory.add(code_message)
        self.publish_message(code_message)

        generated.append("".join((todo, FILENAME_CODE_SEP, str(saved_path))))

    logger.info(f"Done {self.get_workspace()} generating.")
    summary = MSG_SEP.join(generated)
    return Message(
        content=summary,
        role=self.profile,
        cause_by=self._rc.todo,
        send_to="Edward",
    )
async def _act_sp_precision(self) -> Message:
async def _act_sp_precision(self, review=False) -> Message:
code_msg_all = [] # gather all code info, will pass to qa_engineer for tests later
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
for todo in self.todos:
"""
# Select essential information from the historical data to reduce the length of the prompt (summarized from human experience):
@ -184,30 +71,29 @@ class Engineer(Role):
3. Do we need other codes (currently needed)?
TODO: The goal is not to need it. After clear task decomposition, based on the design idea, you should be able to write a single file without needing other codes. If you can't, it means you need a clearer definition. This is the key to writing longer code.
"""
context = []
msg = self._rc.memory.get_by_actions([WriteDesign, WriteTasks, WriteCode])
for m in msg:
context.append(m.content)
context_str = "\n".join(context)
# Write code
code = await WriteCode().run(context=context_str, filename=todo)
coding_context = await todo.run()
# Code review
if self.use_code_review:
if review:
try:
rewrite_code = await WriteCodeReview().run(context=context_str, code=code, filename=todo)
code = rewrite_code
coding_context = await WriteCodeReview(context=coding_context, llm=self._llm).run()
except Exception as e:
logger.error("code review failed!", e)
pass
file_path = self.write_file(todo, code)
msg = Message(content=code, role=self.profile, cause_by=WriteCode)
await src_file_repo.save(
coding_context.filename,
dependencies={coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path},
content=coding_context.code_doc.content,
)
msg = Message(
content=coding_context.json(), instruct_content=coding_context, role=self.profile, cause_by=WriteCode
)
self._rc.memory.add(msg)
self.publish_message(msg)
code_msg = todo + FILENAME_CODE_SEP + str(file_path)
code_msg = coding_context.filename + FILENAME_CODE_SEP + str(coding_context.code_doc.root_relative_path)
code_msg_all.append(code_msg)
logger.info(f"Done {self.get_workspace()} generating.")
logger.info(f"Done {CONFIG.src_workspace} generating.")
msg = Message(
content=MSG_SEP.join(code_msg_all),
role=self.profile,
@ -218,22 +104,92 @@ class Engineer(Role):
async def _act(self) -> Message:
"""Determines the mode of action based on whether code review is used."""
if not self._rc.todo:
return None
if self.use_code_review:
return await self._act_sp_precision()
return await self._act_sp()
return await self._act_sp_precision(review=self.use_code_review)
async def _observe(self) -> int:
ret = await super(Engineer, self)._observe()
if ret == 0:
return ret
async def _think(self) -> Action | None:
if not CONFIG.src_workspace:
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
# Prepare file repos
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
changed_src_files = src_file_repo.changed_files
task_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
changed_task_files = task_file_repo.changed_files
design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
# Parse task lists
for message in self._rc.news:
if not message.cause_by == any_to_str(WriteTasks):
changed_files = Documents()
# 由上游变化导致的recode
for filename in changed_task_files:
design_doc = await design_file_repo.get(filename)
task_doc = await task_file_repo.get(filename)
task_list = self._parse_tasks(task_doc)
for task_filename in task_list:
old_code_doc = await src_file_repo.get(task_filename)
if not old_code_doc:
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=task_filename, content="")
context = CodingContext(
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.json()
)
if task_filename in changed_files.docs:
logger.error(
f"Log to expose potential file name conflicts: {coding_doc.json()} & "
f"{changed_files.docs[task_filename].json()}"
)
changed_files.docs[task_filename] = coding_doc
self.todos = [WriteCode(context=i, llm=self._llm) for i in changed_files.docs.values()]
# 用户直接修改的code
dependency = await CONFIG.git_repo.get_dependency()
for filename in changed_src_files:
if filename in changed_files.docs:
continue
self.todos = self.parse_tasks(message)
return 1
coding_doc = await self._new_coding_doc(
filename=filename,
src_file_repo=src_file_repo,
task_file_repo=task_file_repo,
design_file_repo=design_file_repo,
dependency=dependency,
)
changed_files.docs[filename] = coding_doc
self.todos.append(WriteCode(context=coding_doc, llm=self._llm))
# 仅单测
if CONFIG.REQA_FILENAME and CONFIG.REQA_FILENAME not in changed_files.docs:
context = await self._new_coding_context(
filename=CONFIG.REQA_FILENAME,
src_file_repo=src_file_repo,
task_file_repo=task_file_repo,
design_file_repo=design_file_repo,
dependency=dependency,
)
self.publish_message(Message(content=context.json(), instruct_content=context, cause_by=WriteCode))
return 0
if self.todos:
self._rc.todo = self.todos[0]
return self._rc.todo # For agent store
@staticmethod
async def _new_coding_context(
    filename, src_file_repo, task_file_repo, design_file_repo, dependency
) -> CodingContext:
    """Build the coding context for ``filename`` from its recorded dependencies.

    Args:
        filename: Name of the source file inside ``src_file_repo``.
        src_file_repo: Repository holding the source code documents.
        task_file_repo: Repository holding the task documents.
        design_file_repo: Repository holding the system-design documents.
        dependency: Mapping-like object giving, for a root-relative path,
            the paths that file depends on.

    Returns:
        A ``CodingContext`` holding the existing (or empty) code document and
        the task/design documents found among the dependencies. Either of
        the latter stays ``None`` when no matching dependency exists.
    """
    old_code_doc = await src_file_repo.get(filename)
    if not old_code_doc:
        old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content="")
    # NOTE(review): `dependency.get` is called synchronously, as before;
    # confirm it is not a coroutine function like the file repos' `get`.
    dependencies = {Path(i) for i in dependency.get(old_code_doc.root_relative_path)}
    task_doc = None
    design_doc = None
    for i in dependencies:
        # Repository `get` is a coroutine, so it must be awaited; without the
        # await the context would hold coroutine objects instead of documents.
        # `i` is a pathlib.Path, whose final component is `.name`.
        if str(i.parent) == TASK_FILE_REPO:
            task_doc = await task_file_repo.get(i.name)
        elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
            design_doc = await design_file_repo.get(i.name)
    context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
    return context
@staticmethod
async def _new_coding_doc(filename, src_file_repo, task_file_repo, design_file_repo, dependency):
    """Return a Document whose content is the JSON coding context of ``filename``.

    The document is rooted at ``src_file_repo.root_path`` and named
    ``filename``.
    """
    coding_context = await Engineer._new_coding_context(
        filename=filename,
        src_file_repo=src_file_repo,
        task_file_repo=task_file_repo,
        design_file_repo=design_file_repo,
        dependency=dependency,
    )
    return Document(
        root_path=str(src_file_repo.root_path),
        filename=filename,
        content=coding_context.json(),
    )

View file

@ -151,13 +151,6 @@ class QaEngineer(Role):
)
self.publish_message(msg)
async def _observe(self) -> int:
    """Observe new messages, keep only those sent to this role, and count them."""
    await super()._observe()
    relevant = []
    for msg in self._rc.news:
        # only relevant msgs count as observed news
        if self.profile in msg.send_to:
            relevant.append(msg)
    self._rc.news = relevant
    return len(self._rc.news)
async def _act(self) -> Message:
if self.test_round > self.test_round_allowed:
result_msg = Message(

View file

@ -238,3 +238,10 @@ class MessageQueue:
logger.warning(f"JSON load failed: {v}, error:{e}")
return q
class CodingContext(BaseModel):
    """Documents bundled together for writing or reviewing one source file."""

    # Name of the source file the context is about.
    filename: str
    # System-design document used as part of the prompt context.
    design_doc: Document
    # Task document used as part of the prompt context.
    task_doc: Document
    # Code document for the file; its content is replaced by generated code.
    code_doc: Document

View file

@ -56,6 +56,7 @@ def main(
run_tests: bool = False,
implement: bool = True,
project_path: str = None,
reqa_file: str = None,
):
"""
We are a software startup comprised of AI. By investing in us,
@ -71,6 +72,7 @@ def main(
:return:
"""
CONFIG.WORKDIR = project_path
CONFIG.REQA_FILENAME = reqa_file
asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement))