From 697f790c837f248321d0c7705da6c5ba7d840897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 12 Dec 2023 16:42:08 +0800 Subject: [PATCH] bugfix: write code add related code file context --- metagpt/actions/write_code.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 4c138a124..b20539e78 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -14,12 +14,13 @@ 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. """ +import json from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.const import CODE_SUMMARIES_FILE_REPO, TEST_OUTPUTS_FILE_REPO +from metagpt.const import CODE_SUMMARIES_FILE_REPO, TEST_OUTPUTS_FILE_REPO, TASK_FILE_REPO from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser @@ -101,10 +102,11 @@ class WriteCode(Action): if test_doc: test_detail = RunCodeResult.loads(test_doc.content) logs = test_detail.stderr + code_context = await self._get_codes(coding_context.task_doc) prompt = PROMPT_TEMPLATE.format( design=coding_context.design_doc.content, tasks=coding_context.task_doc.content if coding_context.task_doc else "", - code=coding_context.code_doc.content if coding_context.code_doc else "", + code=code_context, logs=logs, filename=self.context.filename, summary_log=summary_doc.content if summary_doc else "", @@ -115,3 +117,21 @@ class WriteCode(Action): coding_context.code_doc = Document(filename=coding_context.filename, root_path=CONFIG.src_workspace) coding_context.code_doc.content = code return coding_context + + @staticmethod + async def _get_codes(task_doc) -> str: + if not task_doc: + return "" + if not task_doc.content: + task_doc.content = FileRepository.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO) + m = json.loads(task_doc.content) + code_filenames = m.get("Task list", []) + codes = [] + src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace) + for filename in code_filenames: + doc = await src_file_repo.get(filename=filename) + if not doc: + continue + codes.append(doc.content) + return "\n----------\n".join(codes) +