refactor: 代码优化

This commit is contained in:
莘权 马 2023-11-27 19:23:20 +08:00
parent ef9a925281
commit 4c99107a33
2 changed files with 46 additions and 3 deletions

View file

@ -16,6 +16,7 @@ from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import RunCodeResult
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
PROMPT_TEMPLATE = """
NOTICE
@ -50,7 +51,7 @@ class DebugError(Action):
super().__init__(name, context, llm)
async def run(self, *args, **kwargs) -> str:
output_doc = await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).get(self.context.output_filename)
output_doc = await FileRepository.get_file(filename=self.context.output_filename, relative_path=TEST_OUTPUTS_FILE_REPO)
if not output_doc:
return ""
output_detail = RunCodeResult.loads(output_doc.content)
@ -60,10 +61,10 @@ class DebugError(Action):
return ""
logger.info(f"Debug and rewrite {self.context.code_filename}")
code_doc = await CONFIG.git_repo.new_file_repository(CONFIG.src_workspace).get(self.context.code_filename)
code_doc = await FileRepository.get_file(filename=self.context.code_filename, relative_path=CONFIG.src_workspace)
if not code_doc:
return ""
test_doc = await CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(self.context.test_filename)
test_doc = await FileRepository.get_file(filename=self.context.test_filename, relative_path=TEST_CODES_FILE_REPO)
if not test_doc:
return ""
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)

View file

@ -16,6 +16,7 @@ from typing import Dict, List, Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.json_to_markdown import json_to_markdown
@ -186,3 +187,44 @@ class FileRepository:
filename = Path(doc.filename).with_suffix(".md")
await self.save(filename=str(filename), content=json_to_markdown(m))
logger.info(f"File Saved: {str(filename)}")
@staticmethod
async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
"""Retrieve a specific file from the file repository.
:param filename: The name or path of the file to retrieve.
:type filename: Path or str
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: The document representing the file, or None if not found.
:rtype: Document or None
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get(filename=filename)
@staticmethod
async def get_all_files(relative_path: Path | str = ".") -> List[Document]:
"""Retrieve all files from the file repository.
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: A list of documents representing all files in the repository.
:rtype: List[Document]
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get_all()
@staticmethod
async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."):
"""Save a file to the file repository.
:param filename: The name or path of the file to save.
:type filename: Path or str
:param content: The content of the file.
:param dependencies: A list of dependencies for the file.
:type dependencies: List[str], optional
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save(filename=filename, content=content, dependencies=dependencies)