diff --git a/metagpt/actions/di/rewrite_code.py b/metagpt/actions/di/rewrite_code.py index 504ff72c8..0b03b534e 100644 --- a/metagpt/actions/di/rewrite_code.py +++ b/metagpt/actions/di/rewrite_code.py @@ -1,3 +1,4 @@ +import asyncio import os from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -45,24 +46,18 @@ class RewriteCode(Action): task_doc_input='{"Required packages":["No third-party dependencies required"], ...}' ) """ + if not design_doc_input or not task_doc_input: return - code = await aread(code_path) - - # Check if design_doc_input and task_doc_input are paths or content, and read if they are paths - if os.path.exists(design_doc_input): - logger.info(f"read from {design_doc_input}") - design_doc_input = await aread(design_doc_input) - - if os.path.exists(task_doc_input): - logger.info(f"read from {task_doc_input}") - task_doc_input = await aread(task_doc_input) + code, design_doc, task_doc = await asyncio.gather( + aread(code_path), self._try_aread(design_doc_input), self._try_aread(task_doc_input) + ) context = "\n".join( [ - "## System Design\n" + design_doc_input + "\n", - "## Task\n" + task_doc_input + "\n", + "## System Design\n" + design_doc + "\n", + "## Task\n" + task_doc + "\n", ] ) @@ -97,3 +92,12 @@ class RewriteCode(Action): code_rsp = await self._aask(rewrite_prompt) code = CodeParser.parse_code(block="", text=code_rsp) return result, code + + @staticmethod + async def _try_aread(input: str) -> str: + """Try to read from the path if it's a file; return input directly if not.""" + + if os.path.exists(input): + return await aread(input) + + return input