rewrite_code accept content or file_path of design doc

This commit is contained in:
hongjiongteng 2024-06-20 13:03:51 +08:00
parent 76f0d5aad8
commit 896c7d8d4c

View file

@ -1,3 +1,4 @@
import asyncio
import os
from tenacity import retry, stop_after_attempt, wait_random_exponential
@ -45,24 +46,18 @@ class RewriteCode(Action):
task_doc_input='{"Required packages":["No third-party dependencies required"], ...}'
)
"""
if not design_doc_input or not task_doc_input:
return
code = await aread(code_path)
# Check if design_doc_input and task_doc_input are paths or content, and read if they are paths
if os.path.exists(design_doc_input):
logger.info(f"read from {design_doc_input}")
design_doc_input = await aread(design_doc_input)
if os.path.exists(task_doc_input):
logger.info(f"read from {task_doc_input}")
task_doc_input = await aread(task_doc_input)
code, design_doc, task_doc = await asyncio.gather(
aread(code_path), self._try_aread(design_doc_input), self._try_aread(task_doc_input)
)
context = "\n".join(
[
"## System Design\n" + design_doc_input + "\n",
"## Task\n" + task_doc_input + "\n",
"## System Design\n" + design_doc + "\n",
"## Task\n" + task_doc + "\n",
]
)
@ -97,3 +92,12 @@ class RewriteCode(Action):
code_rsp = await self._aask(rewrite_prompt)
code = CodeParser.parse_code(block="", text=code_rsp)
return result, code
@staticmethod
async def _try_aread(input: str) -> str:
"""Try to read from the path if it's a file; return input directly if not."""
if os.path.exists(input):
return await aread(input)
return input