From 6ebb9952b892dd814804f0f6ec46f3a9669cb85b Mon Sep 17 00:00:00 2001 From: liushaojie Date: Fri, 30 Aug 2024 17:43:52 +0800 Subject: [PATCH] update: editor --- metagpt/actions/write_code.py | 4 +- metagpt/actions/write_code_review.py | 4 +- metagpt/ext/cr/actions/modify_code.py | 4 +- metagpt/roles/qa_engineer.py | 4 +- metagpt/tools/libs/cr.py | 4 +- metagpt/tools/libs/editor.py | 145 ++++++++++++++- metagpt/tools/libs/file_io_operator.py | 149 ---------------- metagpt/utils/report.py | 8 +- tests/metagpt/test_reporter.py | 6 +- tests/metagpt/tools/libs/test_editor.py | 226 +++++++++++++++++------- 10 files changed, 317 insertions(+), 237 deletions(-) delete mode 100644 metagpt/tools/libs/file_io_operator.py diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index a2d55ff13..da25fe621 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -29,7 +29,7 @@ from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser, get_markdown_code_block_type from metagpt.utils.project_repo import ProjectRepo -from metagpt.utils.report import FileIOOperatorReporter +from metagpt.utils.report import EditorReporter PROMPT_TEMPLATE = """ NOTICE @@ -152,7 +152,7 @@ class WriteCode(Action): summary_log=summary_doc.content if summary_doc else "", ) logger.info(f"Writing {coding_context.filename}..") - async with FileIOOperatorReporter(enable_llm_stream=True) as reporter: + async with EditorReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "code", "filename": coding_context.filename}, "meta") code = await self.write_code(prompt) if not coding_context.code_doc: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index a7141747a..6a283f812 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -22,7 +22,7 @@ from metagpt.schema import CodingContext, Document from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, aread, awrite from metagpt.utils.project_repo import ProjectRepo -from metagpt.utils.report import FileIOOperatorReporter +from metagpt.utils.report import EditorReporter PROMPT_TEMPLATE = """ # System @@ -144,7 +144,7 @@ class WriteCodeReview(Action): return result, None # if LBTM, rewrite code - async with FileIOOperatorReporter(enable_llm_stream=True) as reporter: + async with EditorReporter(enable_llm_stream=True) as reporter: await reporter.async_report( {"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta" ) diff --git a/metagpt/ext/cr/actions/modify_code.py b/metagpt/ext/cr/actions/modify_code.py index e4c637347..820bdae4a 100644 --- a/metagpt/ext/cr/actions/modify_code.py +++ b/metagpt/ext/cr/actions/modify_code.py @@ -13,7 +13,7 @@ from metagpt.ext.cr.utils.cleaner import ( rm_patch_useless_part, ) from metagpt.utils.common import CodeParser -from metagpt.utils.report import FileIOOperatorReporter +from metagpt.utils.report import EditorReporter SYSTEM_MSGS_PROMPT = """ You're an adaptive software developer who excels at refining code based on user inputs. You're proficient in creating Git patches to represent code modifications. @@ -100,7 +100,7 @@ class ModifyCode(Action): ) patch_file = output_dir / f"{patch_target_file_name}.patch" patch_file.parent.mkdir(exist_ok=True, parents=True) - async with FileIOOperatorReporter(enable_llm_stream=True) as reporter: + async with EditorReporter(enable_llm_stream=True) as reporter: await reporter.async_report( {"type": "Patch", "src_path": str(patch_file), "filename": patch_file.name}, "meta" ) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index de2b27372..fc8fa5353 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -33,7 +33,7 @@ from metagpt.utils.common import ( parse_recipient, ) from metagpt.utils.project_repo import ProjectRepo -from metagpt.utils.report import FileIOOperatorReporter +from metagpt.utils.report import EditorReporter class QaEngineer(Role): @@ -80,7 +80,7 @@ class QaEngineer(Role): context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc) context = await WriteTest(i_context=context, context=self.context, llm=self.llm).run() - async with FileIOOperatorReporter(enable_llm_stream=True) as reporter: + async with EditorReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "test", "filename": test_doc.filename}, "meta") doc = await self.repo.tests.save_doc( diff --git a/metagpt/tools/libs/cr.py b/metagpt/tools/libs/cr.py index 87a686eb1..7d156b4d6 100644 --- a/metagpt/tools/libs/cr.py +++ b/metagpt/tools/libs/cr.py @@ -13,7 +13,7 @@ from metagpt.ext.cr.actions.modify_code import ModifyCode from metagpt.ext.cr.utils.schema import Point from metagpt.tools.libs.browser import Browser from metagpt.tools.tool_registry import register_tool -from metagpt.utils.report import FileIOOperatorReporter +from metagpt.utils.report import EditorReporter @register_tool(tags=["codereview"], include_functions=["review", "fix"]) @@ -86,7 +86,7 @@ class CodeReview: else: async with aiofiles.open(patch_path, encoding="utf-8") as f: patch_file_content = await f.read() - await FileIOOperatorReporter().async_report(patch_path) + await EditorReporter().async_report(patch_path) if not patch_path.endswith((".diff", ".patch")): name = Path(patch_path).name patch_file_content = "".join( diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 71f297acd..81f2bd4a7 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -3,26 +3,38 @@ This file is borrowed from OpenDevin You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ - +import base64 import os import re import shutil import tempfile from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Tuple, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict +from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool +from metagpt.utils import read_docx +from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint +from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.report import EditorReporter # This is also used in unit tests! MSG_FILE_UPDATED = "[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]" LINTER_ERROR_MSG = "[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n" +class FileBlock(BaseModel): + """A block of content in a file""" + + file_path: str + block_content: str + + class LineNumberError(Exception): pass @@ -30,16 +42,133 @@ class LineNumberError(Exception): @register_tool() class Editor(BaseModel): """ - A state-of-state tool for open, reading, and editing files. + A state-of-state tool for open/reading, understanding, and editing/writing files. Args: working_dir: The working directory to use for the editor. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + resource: EditorReporter = EditorReporter() current_file: Optional[Path] = None current_line: int = 1 window: int = 100 + enable_auto_lint: bool = False working_dir: Path = DEFAULT_WORKSPACE_ROOT + def write(self, path: str, content: str): + """Write the whole content to a file. When used, make sure content arg contains the full content of the file.""" + if "\n" not in content and "\\n" in content: + # A very raw rule to correct the content: If 'content' lacks actual newlines ('\n') but includes '\\n', consider + # replacing them with '\n' to potentially correct mistaken representations of newline characters. + content = content.replace("\\n", "\n") + directory = os.path.dirname(path) + if directory and not os.path.exists(directory): + os.makedirs(directory) + with open(path, "w", encoding="utf-8") as f: + f.write(content) + # self.resource.report(path, "path") + return f"The writing/coding the of the file {os.path.basename(path)}' is now completed. The file '{os.path.basename(path)}' has been successfully created." + + async def read(self, path: str) -> FileBlock: + """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" + is_text, mime_type = await is_text_file(path) + if is_text: + lines = await self._read_text(path) + elif mime_type == "application/pdf": + lines = await self._read_pdf(path) + elif mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + lines = await self._read_docx(path) + else: + return FileBlock(file_path=str(path), block_content="") + self.resource.report(str(path), "path") + + lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] + result = FileBlock( + file_path=str(path), + block_content="".join(lines_with_num), + ) + return result + + @staticmethod + async def _read_text(path: Union[str, Path]) -> List[str]: + content = await aread(path) + lines = content.split("\n") + return lines + + @staticmethod + async def _read_pdf(path: Union[str, Path]) -> List[str]: + result = await Editor._omniparse_read_file(path) + if result: + return result + + from llama_index.readers.file import PDFReader + + reader = PDFReader() + lines = reader.load_data(file=Path(path)) + return [i.text for i in lines] + + @staticmethod + async def _read_docx(path: Union[str, Path]) -> List[str]: + result = await Editor._omniparse_read_file(path) + if result: + return result + return read_docx(str(path)) + + @staticmethod + async def _omniparse_read_file(path: Union[str, Path]) -> Optional[List[str]]: + from metagpt.tools.libs import get_env_default + from metagpt.utils.omniparse_client import OmniParseClient + + env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") + env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") + conf_base_url, conf_timeout = await Editor._read_omniparse_config() + + base_url = env_base_url or conf_base_url + if not base_url: + return None + api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") + timeout = env_timeout or conf_timeout or 600 + try: + timeout = int(timeout) + except ValueError: + timeout = 600 + + try: + if not await check_http_endpoint(url=base_url): + logger.warning(f"{base_url}: NOT AVAILABLE") + return None + client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) + file_data = await aread_bin(filename=path) + ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) + except (ValueError, Exception) as e: + logger.exception(f"{path}: {e}") + return None + if not ret.images: + return [ret.text] if ret.text else None + + result = [ret.text] + img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") + img_dir.mkdir(parents=True, exist_ok=True) + for i in ret.images: + byte_data = base64.b64decode(i.image) + filename = img_dir / i.image_name + await awrite_bin(filename=filename, data=byte_data) + result.append(f"![{i.image_name}]({str(filename)})") + return result + + @staticmethod + async def _read_omniparse_config() -> Tuple[str, int]: + config = Config.default() + if config.omniparse and config.omniparse.url: + return config.omniparse.url, config.omniparse.timeout + return "", 0 + @staticmethod def _is_valid_filename(file_name: str) -> bool: if not file_name or not file_name.strip(): @@ -422,8 +551,8 @@ class Editor(BaseModel): try: # lint the original file - enable_auto_lint = os.getenv("ENABLE_AUTO_LINT", "false").lower() == "true" - if enable_auto_lint: + # enable_auto_lint = os.getenv("ENABLE_AUTO_LINT", "false").lower() == "true" + if self.enable_auto_lint: original_lint_error, _ = self._lint_file(file_name) # Create a temporary file @@ -461,7 +590,7 @@ class Editor(BaseModel): # Handle linting # NOTE: we need to get env var inside this function # because the env var will be set AFTER the agentskills is imported - if enable_auto_lint: + if self.enable_auto_lint: # BACKUP the original file original_file_backup_path = file_name.parent / f".backup.{file_name.name}" with original_file_backup_path.open("w") as f: @@ -803,7 +932,7 @@ class Editor(BaseModel): matches = [] for root, _, files in os.walk(dir_path): for file in files: - if file_name in file: + if str(file_name) in file: matches.append(Path(root) / file) res_list = [] diff --git a/metagpt/tools/libs/file_io_operator.py b/metagpt/tools/libs/file_io_operator.py deleted file mode 100644 index 29578789d..000000000 --- a/metagpt/tools/libs/file_io_operator.py +++ /dev/null @@ -1,149 +0,0 @@ -import base64 -import os -from pathlib import Path -from typing import List, Optional, Tuple, Union - -from pydantic import BaseModel, ConfigDict - -from metagpt.config2 import Config -from metagpt.logs import logger -from metagpt.tools.tool_registry import register_tool -from metagpt.utils import read_docx -from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint -from metagpt.utils.repo_to_markdown import is_text_file -from metagpt.utils.report import FileIOOperatorReporter - - -class FileBlock(BaseModel): - """A block of content in a file""" - - file_path: str - block_content: str - - -class LineNumberError(Exception): - pass - - -@register_tool() -class FileOperator(BaseModel): - """ - A state-of-state tool for reading, understanding, and writing files. - """ - - model_config = ConfigDict(arbitrary_types_allowed=True) - resource: FileIOOperatorReporter = FileIOOperatorReporter() - - def write(self, path: str, content: str): - """Write the whole content to a file. When used, make sure content arg contains the full content of the file.""" - if "\n" not in content and "\\n" in content: - # A very raw rule to correct the content: If 'content' lacks actual newlines ('\n') but includes '\\n', consider - # replacing them with '\n' to potentially correct mistaken representations of newline characters. - content = content.replace("\\n", "\n") - directory = os.path.dirname(path) - if directory and not os.path.exists(directory): - os.makedirs(directory) - with open(path, "w", encoding="utf-8") as f: - f.write(content) - # self.resource.report(path, "path") - return f"The writing/coding the of the file {os.path.basename(path)}' is now completed. The file '{os.path.basename(path)}' has been successfully created." - - async def read(self, path: str) -> FileBlock: - """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" - is_text, mime_type = await is_text_file(path) - if is_text: - lines = await self._read_text(path) - elif mime_type == "application/pdf": - lines = await self._read_pdf(path) - elif mime_type in { - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-word.document.macroEnabled.12", - "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - "application/vnd.ms-word.template.macroEnabled.12", - }: - lines = await self._read_docx(path) - else: - return FileBlock(file_path=str(path), block_content="") - self.resource.report(str(path), "path") - - lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] - result = FileBlock( - file_path=str(path), - block_content="".join(lines_with_num), - ) - return result - - @staticmethod - async def _read_text(path: Union[str, Path]) -> List[str]: - content = await aread(path) - lines = content.split("\n") - return lines - - @staticmethod - async def _read_pdf(path: Union[str, Path]) -> List[str]: - result = await FileOperator._omniparse_read_file(path) - if result: - return result - - from llama_index.readers.file import PDFReader - - reader = PDFReader() - lines = reader.load_data(file=Path(path)) - return [i.text for i in lines] - - @staticmethod - async def _read_docx(path: Union[str, Path]) -> List[str]: - result = await FileOperator._omniparse_read_file(path) - if result: - return result - return read_docx(str(path)) - - @staticmethod - async def _omniparse_read_file(path: Union[str, Path]) -> Optional[List[str]]: - from metagpt.tools.libs import get_env_default - from metagpt.utils.omniparse_client import OmniParseClient - - env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") - env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") - conf_base_url, conf_timeout = await FileOperator._read_omniparse_config() - - base_url = env_base_url or conf_base_url - if not base_url: - return None - api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") - timeout = env_timeout or conf_timeout or 600 - try: - timeout = int(timeout) - except ValueError: - timeout = 600 - - try: - if not await check_http_endpoint(url=base_url): - logger.warning(f"{base_url}: NOT AVAILABLE") - return None - client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) - file_data = await aread_bin(filename=path) - ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) - except (ValueError, Exception) as e: - logger.exception(f"{path}: {e}") - return None - if not ret.images: - return [ret.text] if ret.text else None - - result = [ret.text] - img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") - img_dir.mkdir(parents=True, exist_ok=True) - for i in ret.images: - byte_data = base64.b64decode(i.image) - filename = img_dir / i.image_name - await awrite_bin(filename=filename, data=byte_data) - result.append(f"![{i.image_name}]({str(filename)})") - return result - - @staticmethod - async def _read_omniparse_config() -> Tuple[str, int]: - config = Config.default() - if config.omniparse and config.omniparse.url: - return config.omniparse.url, config.omniparse.timeout - return "", 0 diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py index 5021011d2..427f401ab 100644 --- a/metagpt/utils/report.py +++ b/metagpt/utils/report.py @@ -35,7 +35,7 @@ class BlockType(str, Enum): TASK = "Task" BROWSER = "Browser" BROWSER_RT = "Browser-RT" - FILE_IO_OPERATOR = "FileIOOperator" + EDITOR = "Editor" GALLERY = "Gallery" NOTEBOOK = "Notebook" DOCS = "Docs" @@ -305,10 +305,10 @@ class DocsReporter(FileReporter): block: Literal[BlockType.DOCS] = BlockType.DOCS -class FileIOOperatorReporter(FileReporter): - """Equivalent to FileReporter(block=BlockType.FileIOOperator).""" +class EditorReporter(FileReporter): + """Equivalent to FileReporter(block=BlockType.EDITOR).""" - block: Literal[BlockType.FILE_IO_OPERATOR] = BlockType.FILE_IO_OPERATOR + block: Literal[BlockType.EDITOR] = BlockType.EDITOR class GalleryReporter(FileReporter): diff --git a/tests/metagpt/test_reporter.py b/tests/metagpt/test_reporter.py index b1a0918a5..41d963448 100644 --- a/tests/metagpt/test_reporter.py +++ b/tests/metagpt/test_reporter.py @@ -10,7 +10,7 @@ from metagpt.utils.report import ( BlockType, BrowserReporter, DocsReporter, - FileIOOperatorReporter, + EditorReporter, NotebookReporter, ServerReporter, TaskReporter, @@ -148,8 +148,8 @@ async def test_notebook_reporter(http_server): "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nprint('Hello World')\n", "/data/main.py", {"type": "write_code"}, - BlockType.FILE_IO_OPERATOR, - FileIOOperatorReporter, + BlockType.EDITOR, + EditorReporter, ), ), ids=["test_docs_reporter", "test_editor_reporter"], diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 18b1400bf..e2774ddc5 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -13,17 +13,24 @@ def test_function_for_fm(): # this is the 7th line """.strip() -TEST_FILE_PATH = TEST_DATA_PATH / "tools/test_script_for_editor.py" WINDOW = 100 @pytest.fixture -def test_file(): - with open(TEST_FILE_PATH, "w") as f: - f.write(TEST_FILE_CONTENT) - yield - with open(TEST_FILE_PATH, "w") as f: - f.write("") +def temp_file_path(tmp_path): + assert tmp_path is not None + temp_file_path = tmp_path / "a.txt" + yield temp_file_path + temp_file_path.unlink() + + +@pytest.fixture +def temp_py_file(tmp_path): + assert tmp_path is not None + temp_file_path = tmp_path / "test_script_for_editor.py" + temp_file_path.write_text(TEST_FILE_CONTENT) + yield temp_file_path + temp_file_path.unlink() EXPECTED_CONTENT_AFTER_REPLACE = """ @@ -36,17 +43,17 @@ def test_function_for_fm(): """.strip() -def test_replace_content(test_file): +def test_replace_content(temp_py_file): editor = Editor() editor._edit_file_impl( - file_name=TEST_FILE_PATH, + file_name=temp_py_file, start=3, end=5, content=" # This is the new line A replacing lines 3 to 5.\n # This is the new line B.", is_insert=False, is_append=False, ) - with open(TEST_FILE_PATH, "r") as f: + with open(temp_py_file, "r") as f: new_content = f.read() assert new_content.strip() == EXPECTED_CONTENT_AFTER_REPLACE.strip() @@ -60,17 +67,17 @@ def test_function_for_fm(): """.strip() -def test_delete_content(test_file): +def test_delete_content(temp_py_file): editor = Editor() editor._edit_file_impl( - file_name=TEST_FILE_PATH, + file_name=temp_py_file, start=3, end=5, content="", is_insert=False, is_append=False, ) - with open(TEST_FILE_PATH, "r") as f: + with open(temp_py_file, "r") as f: new_content = f.read() assert new_content.strip() == EXPECTED_CONTENT_AFTER_DELETE.strip() @@ -87,17 +94,14 @@ def test_function_for_fm(): """.strip() -def test_insert_content(test_file): - editor = Editor() - editor._edit_file_impl( - file_name=TEST_FILE_PATH, - start=3, - end=3, +def test_insert_content(temp_py_file): + editor = Editor(enable_auto_lint=True) + editor.insert_content_at_line( + file_name=temp_py_file, + line_number=3, content=" # This is the new line to be inserted, at line 3", - is_insert=True, - is_append=False, ) - with open(TEST_FILE_PATH, "r") as f: + with open(temp_py_file, "r") as f: new_content = f.read() assert new_content.strip() == EXPECTED_CONTENT_AFTER_INSERT.strip() @@ -161,10 +165,8 @@ def test_open_file_unexist_path(): editor.open_file("/unexist/path/a.txt") -def test_open_file(tmp_path): +def test_open_file(temp_file_path): editor = Editor() - assert tmp_path is not None - temp_file_path = tmp_path / "a.txt" temp_file_path.write_text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5") result = editor.open_file(str(temp_file_path)) @@ -183,9 +185,8 @@ def test_open_file(tmp_path): assert result.split("\n") == expected.split("\n") -def test_open_file_with_indentation(tmp_path): +def test_open_file_with_indentation(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" temp_file_path.write_text("Line 1\n Line 2\nLine 3\nLine 4\nLine 5") result = editor.open_file(str(temp_file_path)) @@ -203,9 +204,8 @@ def test_open_file_with_indentation(tmp_path): assert result.split("\n") == expected.split("\n") -def test_open_file_long(tmp_path): +def test_open_file_long(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 1001)]) temp_file_path.write_text(content) @@ -219,9 +219,8 @@ def test_open_file_long(tmp_path): assert result.split("\n") == expected.split("\n") -def test_open_file_long_with_lineno(tmp_path): +def test_open_file_long_with_lineno(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 1001)]) temp_file_path.write_text(content) @@ -250,18 +249,16 @@ def test_create_file_unexist_path(): editor.create_file("/unexist/path/a.txt") -def test_create_file(tmp_path): +def test_create_file(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" result = editor.create_file(str(temp_file_path)) expected = f"[File {temp_file_path} created.]" assert result.split("\n") == expected.split("\n") -def test_goto_line(tmp_path): +def test_goto_line(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" total_lines = 1000 content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) temp_file_path.write_text(content) @@ -296,9 +293,8 @@ def test_goto_line(tmp_path): assert result.split("\n") == expected.split("\n") -def test_goto_line_negative(tmp_path): +def test_goto_line_negative(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 5)]) temp_file_path.write_text(content) @@ -307,9 +303,8 @@ def test_goto_line_negative(tmp_path): editor.goto_line(-1) -def test_goto_line_out_of_bound(tmp_path): +def test_goto_line_out_of_bound(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 5)]) temp_file_path.write_text(content) @@ -318,9 +313,8 @@ def test_goto_line_out_of_bound(tmp_path): editor.goto_line(100) -def test_scroll_down(tmp_path): +def test_scroll_down(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" total_lines = 1000 content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) temp_file_path.write_text(content) @@ -360,9 +354,8 @@ def test_scroll_down(tmp_path): assert result.split("\n") == expected.split("\n") -def test_scroll_up(tmp_path): +def test_scroll_up(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" total_lines = 1000 content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) temp_file_path.write_text(content) @@ -405,9 +398,8 @@ def test_scroll_up(tmp_path): assert result.split("\n") == expected.split("\n") -def test_scroll_down_edge(tmp_path): +def test_scroll_down_edge(temp_file_path): editor = Editor() - temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 10)]) temp_file_path.write_text(content) @@ -426,36 +418,34 @@ def test_scroll_down_edge(tmp_path): assert result.split("\n") == expected.split("\n") -def test_print_window_internal(tmp_path): +def test_print_window_internal(temp_file_path): editor = Editor() - test_file_path = tmp_path / "a.txt" - editor.create_file(str(test_file_path)) - with open(test_file_path, "w") as file: + editor.create_file(str(temp_file_path)) + with open(temp_file_path, "w") as file: for i in range(1, 101): file.write(f"Line `{i}`\n") current_line = 50 window = 2 - result = editor._print_window(test_file_path, current_line, window) + result = editor._print_window(temp_file_path, current_line, window) expected = "(48 more lines above)\n" "49|Line `49`\n" "50|Line `50`\n" "51|Line `51`\n" "(49 more lines below)" assert result == expected -def test_open_file_large_line_number(tmp_path): +def test_open_file_large_line_number(temp_file_path): editor = Editor() - test_file_path = tmp_path / "a.txt" - editor.create_file(str(test_file_path)) - with open(test_file_path, "w") as file: + editor.create_file(str(temp_file_path)) + with open(temp_file_path, "w") as file: for i in range(1, 1000): file.write(f"Line `{i}`\n") current_line = 800 window = 100 - result = editor.open_file(str(test_file_path), current_line, window) + result = editor.open_file(str(temp_file_path), current_line, window) - expected = f"[File: {test_file_path} (999 lines total)]\n" + expected = f"[File: {temp_file_path} (999 lines total)]\n" expected += "(749 more lines above)\n" for i in range(750, 850 + 1): expected += f"{i}|Line `{i}`\n" @@ -463,21 +453,20 @@ def test_open_file_large_line_number(tmp_path): assert result == expected -def test_open_file_large_line_number_consecutive_diff_window(tmp_path): +def test_open_file_large_line_number_consecutive_diff_window(temp_file_path): editor = Editor() - test_file_path = tmp_path / "a.txt" - editor.create_file(str(test_file_path)) + editor.create_file(str(temp_file_path)) total_lines = 1000 - with open(test_file_path, "w") as file: + with open(temp_file_path, "w") as file: for i in range(1, total_lines + 1): file.write(f"Line `{i}`\n") current_line = 800 cur_window = 300 - result = editor.open_file(str(test_file_path), current_line, cur_window) + result = editor.open_file(str(temp_file_path), current_line, cur_window) - expected = f"[File: {test_file_path} ({total_lines} lines total)]\n" + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" start, end = _calculate_window_bounds(current_line, total_lines, cur_window) if start == 1: expected += "(this is the beginning of the file)\n" @@ -495,7 +484,7 @@ def test_open_file_large_line_number_consecutive_diff_window(tmp_path): result = editor.scroll_up() - expected = f"[File: {test_file_path} ({total_lines} lines total)]\n" + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" start, end = _calculate_window_bounds(current_line, total_lines, WINDOW) if start == 1: expected += "(this is the beginning of the file)\n" @@ -510,5 +499,116 @@ def test_open_file_large_line_number_consecutive_diff_window(tmp_path): assert result.split("\n") == expected.split("\n") +EXPECTED_CONTENT_AFTER_REPLACE_TEXT = """ +# this is line one +def test_function_for_fm(): + "some docstring" + a = 1 + b = 9 + c = 3 + # this is the 7th line +""".strip() + + +def test_edit_file_by_replace(temp_py_file): + editor = Editor() + editor.edit_file_by_replace(file_name=str(temp_py_file), to_replace=" b = 2", new_content=" b = 9") + with open(temp_py_file, "r") as f: + new_content = f.read() + assert new_content.strip() == EXPECTED_CONTENT_AFTER_REPLACE_TEXT.strip() + + +def test_search_dir(tmp_path): + editor = Editor() + dir_path = tmp_path / "test_dir" + dir_path.mkdir() + + # Create some files with specific content + (dir_path / "file1.txt").write_text("This is a test file with some content.") + (dir_path / "file2.txt").write_text("Another file with different content.") + sub_dir = dir_path / "sub_dir" + sub_dir.mkdir() + (sub_dir / "file3.txt").write_text("This file is inside a sub directory with some content.") + + search_term = "some content" + + result = editor.search_dir(search_term, str(dir_path)) + + assert "file1.txt" in result + assert "file3.txt" in result + assert "Another file with different content." not in result + + +def test_search_file(temp_file_path): + editor = Editor() + file_path = temp_file_path + file_path.write_text("This is a test file with some content.\nAnother line with more content.") + + search_term = "some content" + + result = editor.search_file(search_term, str(file_path)) + + assert "Line 1: This is a test file with some content." in result + assert "Line 2: Another line with more content." not in result + + +def test_find_file(tmp_path): + editor = Editor() + dir_path = tmp_path / "test_dir" + dir_path.mkdir() + + # Create some files with specific names + (dir_path / "file1.txt").write_text("Content of file 1.") + (dir_path / "file2.txt").write_text("Content of file 2.") + sub_dir = dir_path / "sub_dir" + sub_dir.mkdir() + (sub_dir / "file3.txt").write_text("Content of file 3.") + + file_name = "file1.txt" + + result = editor.find_file(file_name, str(dir_path)) + + assert "file1.txt" in result + assert "file2.txt" not in result + assert "file3.txt" not in result + + +# Test data for _append_impl method +TEST_LINES = ["First line\n", "Second line\n", "Third line\n"] + +NEW_CONTENT = "Appended line\n" + +EXPECTED_APPEND_NON_EMPTY_FILE = ["First line\n", "Second line\n", "Third line\n", "Appended line\n"] + +EXPECTED_APPEND_EMPTY_FILE = ["Appended line\n"] + + +def test_append_non_empty_file(): + editor = Editor() + lines = TEST_LINES.copy() + content, n_added_lines = editor._append_impl(lines, NEW_CONTENT) + + assert content.splitlines(keepends=True) == EXPECTED_APPEND_NON_EMPTY_FILE + assert n_added_lines == 1 + + +def test_append_empty_file(): + editor = Editor() + lines = [] + content, n_added_lines = editor._append_impl(lines, NEW_CONTENT) + + assert content.splitlines(keepends=True) == EXPECTED_APPEND_EMPTY_FILE + assert n_added_lines == 1 + + +def test_append_to_single_empty_line_file(): + editor = Editor() + lines = [""] + content, n_added_lines = editor._append_impl(lines, NEW_CONTENT) + + assert content.splitlines(keepends=True) == EXPECTED_APPEND_EMPTY_FILE + assert n_added_lines == 1 + + if __name__ == "__main__": pytest.main([__file__, "-s"])