diff --git a/metagpt/prompts/di/swe_agent.py b/metagpt/prompts/di/swe_agent.py index b543c01d5..86a062214 100644 --- a/metagpt/prompts/di/swe_agent.py +++ b/metagpt/prompts/di/swe_agent.py @@ -183,9 +183,7 @@ IMPORTANT_TIPS = """ 15. When the edit fails, try to enlarge the starting line. -16. Use an absolute path instead of a relative path. - -17. Once again, and this is critical: YOU CAN ONLY ENTER ONE COMMAND AT A TIME. +16. Once again, and this is critical: YOU CAN ONLY ENTER ONE COMMAND AT A TIME. """ NEXT_STEP_TEMPLATE = f""" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index e32292b96..d2486b89e 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -71,7 +71,7 @@ class RoleZero(Role): tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools tool_recommender: Optional[ToolRecommender] = None tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {} - special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"] + special_tool_commands: list[str] = ["Plan.finish_current_task", "end"] # Equipped with three basic tools by default for optional use editor: Editor = Editor() browser: Browser = Browser() @@ -140,12 +140,11 @@ class RoleZero(Role): "goto_line", "insert_content_at_line", "open_file", - # "read", "scroll_down", "scroll_up", "search_dir", "search_file", - # "write", + "set_workdir", ] } ) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 806098522..71f297acd 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -1,9 +1,15 @@ +""" +This file is borrowed from OpenDevin +You can find the original repository here: +https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py +""" + import os import re import shutil import tempfile from pathlib import Path -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel @@ -25,6 +31,8 @@ class LineNumberError(Exception): class Editor(BaseModel): """ A state-of-state tool for open, reading, and editing files. + Args: + working_dir: The working directory to use for the editor. """ current_file: Optional[Path] = None @@ -74,22 +82,22 @@ class Editor(BaseModel): def _clamp(value, min_value, max_value): return max(min_value, min(value, max_value)) - @staticmethod - def _lint_file(file_path: Path) -> tuple[Optional[str], Optional[int]]: + def _lint_file(self, file_path: Path) -> tuple[Optional[str], Optional[int]]: """Lint the file at the given path and return a tuple with a boolean indicating if there are errors, and the line number of the first error, if any. Returns: tuple[str | None, int | None]: (lint_error, first_error_line_number) """ - linter = Linter(root=os.getcwd()) + + linter = Linter(root=self.working_dir) lint_error = linter.lint(str(file_path)) if not lint_error: # Linting successful. No issues found. return None, None return "ERRORS:\n" + lint_error.text, lint_error.lines[0] - def _print_window(self, file_path: Path, targeted_line: int, window: int, return_str: bool = False): + def _print_window(self, file_path: Path, targeted_line: int, window: int): self._check_current_file(file_path) with file_path.open() as file: content = file.read() @@ -133,10 +141,7 @@ class Editor(BaseModel): output += "(this is the end of the file)\n" output = output.rstrip() - if return_str: - return output - else: - logger.info(output) + return output @staticmethod def _cur_file_header(current_file: Path, total_lines: int) -> str: @@ -154,7 +159,9 @@ class Editor(BaseModel): """ self.working_dir = Path(path) - def open_file(self, path: str, line_number: Optional[int] = 1, context_lines: Optional[int] = None) -> None: + def open_file( + self, path: Union[Path, str], line_number: Optional[int] = 1, context_lines: Optional[int] = None + ) -> str: """Opens the file at the given path in the editor. If line_number is provided, the window will be moved to include that line. It only shows the first 100 lines by default! Max `context_lines` supported is 2000, use `scroll up/down` to view the file if you want to see more. @@ -167,7 +174,7 @@ class Editor(BaseModel): if context_lines is None: context_lines = self.window - path = self.working_dir / Path(path) + path = self._try_fix_path(path) if not path.is_file(): raise FileNotFoundError(f"File {path} not found") @@ -185,10 +192,10 @@ class Editor(BaseModel): context_lines = self.window output = self._cur_file_header(path, total_lines) - output += self._print_window(path, self.current_line, self._clamp(context_lines, 1, 2000), return_str=True) - logger.info(output) + output += self._print_window(path, self.current_line, self._clamp(context_lines, 1, 2000)) + return output - def goto_line(self, line_number: int) -> None: + def goto_line(self, line_number: int) -> str: """Moves the window to show the specified line number. Args: @@ -204,10 +211,10 @@ class Editor(BaseModel): self.current_line = self._clamp(line_number, 1, total_lines) output = self._cur_file_header(self.current_file, total_lines) - output += self._print_window(self.current_file, self.current_line, self.window, return_str=True) - logger.info(output) + output += self._print_window(self.current_file, self.current_line, self.window) + return output - def scroll_down(self) -> None: + def scroll_down(self) -> str: """Moves the window down by 100 lines.""" self._check_current_file() @@ -215,10 +222,10 @@ class Editor(BaseModel): total_lines = max(1, sum(1 for _ in file)) self.current_line = self._clamp(self.current_line + self.window, 1, total_lines) output = self._cur_file_header(self.current_file, total_lines) - output += self._print_window(self.current_file, self.current_line, self.window, return_str=True) - logger.info(output) + output += self._print_window(self.current_file, self.current_line, self.window) + return output - def scroll_up(self) -> None: + def scroll_up(self) -> str: """Moves the window up by 100 lines.""" self._check_current_file() @@ -226,16 +233,16 @@ class Editor(BaseModel): total_lines = max(1, sum(1 for _ in file)) self.current_line = self._clamp(self.current_line - self.window, 1, total_lines) output = self._cur_file_header(self.current_file, total_lines) - output += self._print_window(self.current_file, self.current_line, self.window, return_str=True) - logger.info(output) + output += self._print_window(self.current_file, self.current_line, self.window) + return output - def create_file(self, filename: str) -> None: + def create_file(self, filename: str) -> str: """Creates and opens a new file with the given name. Args: filename: str: The name of the file to create. """ - filename = self.working_dir / Path(filename) + filename = self._try_fix_path(filename) if filename.exists(): raise FileExistsError(f"File '{filename}' already exists.") @@ -244,7 +251,7 @@ class Editor(BaseModel): file.write("\n") self.open_file(filename) - logger.info(f"[File {filename} created.]") + return f"[File {filename} created.]" @staticmethod def _append_impl(lines, content): @@ -345,11 +352,22 @@ class Editor(BaseModel): if start > end: raise LineNumberError(f"Invalid line range: {start}-{end}. Start must be less than or equal to end.") + # Split content into lines and ensure it ends with a newline if not content.endswith("\n"): content += "\n" content_lines = content.splitlines(True) + + # Calculate the number of lines to be added n_added_lines = len(content_lines) + + # Remove the specified range of lines and insert the new content new_lines = lines[: start - 1] + content_lines + lines[end:] + + # Handle the case where the original lines are empty + if len(lines) == 0: + new_lines = content_lines + + # Join the lines to create the new content content = "".join(new_lines) return content, n_added_lines @@ -403,8 +421,6 @@ class Editor(BaseModel): first_error_line = None try: - n_added_lines = None - # lint the original file enable_auto_lint = os.getenv("ENABLE_AUTO_LINT", "false").lower() == "true" if enable_auto_lint: @@ -506,7 +522,6 @@ class Editor(BaseModel): original_file_backup_path, show_line, editor_lines, - return_str=True, ) + "\n" ) @@ -549,11 +564,11 @@ class Editor(BaseModel): self.current_line = start or n_total_lines or 1 ret_str += f"[File: {file_name.resolve()} ({n_total_lines} lines total after edit)]\n" CURRENT_FILE = file_name - ret_str += self._print_window(CURRENT_FILE, self.current_line, self.window, return_str=True) + "\n" + ret_str += self._print_window(CURRENT_FILE, self.current_line, self.window) + "\n" ret_str += MSG_FILE_UPDATED.format(line_number=self.current_line) return ret_str - def edit_file_by_replace(self, file_name: str, to_replace: str, new_content: str) -> None: + def edit_file_by_replace(self, file_name: str, to_replace: str, new_content: str) -> str: """Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`. Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc. @@ -609,7 +624,7 @@ class Editor(BaseModel): # search for `to_replace` in the file # if found, replace it with `new_content` # if not found, perform a fuzzy search to find the closest match and replace it with `new_content` - file_name = self.working_dir / Path(file_name) + file_name = self._try_fix_path(file_name) with file_name.open("r") as file: file_content = file.read() @@ -635,8 +650,7 @@ class Editor(BaseModel): # find the closest match start = file_content_fuzzy.find(to_replace_fuzzy) if start == -1: - logger.info(f"[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]") - return + return f"[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]" # Convert start from index to line number for fuzzy match start_line_number = file_content_fuzzy[:start].count("\n") + 1 end_line_number = start_line_number + len(to_replace.splitlines()) - 1 @@ -650,9 +664,9 @@ class Editor(BaseModel): ) # lint_error = bool(LINTER_ERROR_MSG in ret_str) # TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation) - logger.info(ret_str) + return ret_str - def insert_content_at_line(self, file_name: str, line_number: int, content: str) -> None: + def insert_content_at_line(self, file_name: str, line_number: int, content: str) -> str: """Insert content at the given line number in a file. This will NOT modify the content of the lines before OR after the given line number. @@ -675,7 +689,8 @@ class Editor(BaseModel): line_number: int: The line number (starting from 1) to insert the content after. content: str: The content to insert. """ - file_name = self.working_dir / Path(file_name) + file_name = self._try_fix_path(file_name) + ret_str = self._edit_file_impl( file_name, start=line_number, @@ -684,9 +699,9 @@ class Editor(BaseModel): is_insert=True, is_append=False, ) - logger.info(ret_str) + return ret_str - def append_file(self, file_name: str, content: str) -> None: + def append_file(self, file_name: str, content: str) -> str: """Append content to the given file. It appends text `content` to the end of the specified file. @@ -694,7 +709,8 @@ class Editor(BaseModel): file_name: str: The name of the file to edit. content: str: The content to insert. """ - file_name = self.working_dir / Path(file_name) + file_name = self._try_fix_path(file_name) + ret_str = self._edit_file_impl( file_name, start=None, @@ -703,16 +719,16 @@ class Editor(BaseModel): is_insert=False, is_append=True, ) - logger.info(ret_str) + return ret_str - def search_dir(self, search_term: str, dir_path: str = "./") -> None: + def search_dir(self, search_term: str, dir_path: str = "./") -> str: """Searches for search_term in all files in dir. If dir is not provided, searches in the current directory. Args: search_term: str: The term to search for. dir_path: str: The path to the directory to search. """ - dir_path = self.working_dir / Path(dir_path) + dir_path = self._try_fix_path(dir_path) if not dir_path.is_dir(): raise FileNotFoundError(f"Directory {dir_path} not found") matches = [] @@ -727,24 +743,21 @@ class Editor(BaseModel): matches.append((file_path, line_num, line.strip())) if not matches: - logger.info(f'No matches found for "{search_term}" in {dir_path}') - return + return f'No matches found for "{search_term}" in {dir_path}' num_matches = len(matches) num_files = len(set(match[0] for match in matches)) if num_files > 100: - logger.info( - f'More than {num_files} files matched for "{search_term}" in {dir_path}. Please narrow your search.' - ) - return + return f'More than {num_files} files matched for "{search_term}" in {dir_path}. Please narrow your search.' - logger.info(f'[Found {num_matches} matches for "{search_term}" in {dir_path}]') + res_list = [f'[Found {num_matches} matches for "{search_term}" in {dir_path}]'] for file_path, line_num, line in matches: - logger.info(f"{file_path} (Line {line_num}): {line}") - logger.info(f'[End of matches for "{search_term}" in {dir_path}]') + res_list.append(f"{file_path} (Line {line_num}): {line}") + res_list.append(f'[End of matches for "{search_term}" in {dir_path}]') + return "\n".join(res_list) - def search_file(self, search_term: str, file_path: Optional[str] = None) -> None: + def search_file(self, search_term: str, file_path: Optional[str] = None) -> str: """Searches for search_term in file. If file is not provided, searches in the current open file. Args: @@ -754,7 +767,7 @@ class Editor(BaseModel): if file_path is None: file_path = self.current_file else: - file_path = self.working_dir / Path(file_path) + file_path = self._try_fix_path(file_path) if file_path is None: raise FileNotFoundError("No file specified or open. Use the open_file function first.") if not file_path.is_file(): @@ -765,24 +778,25 @@ class Editor(BaseModel): for i, line in enumerate(file, 1): if search_term in line: matches.append((i, line.strip())) - + res_list = [] if matches: - logger.info(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]') + res_list.append(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]') for match in matches: - logger.info(f"Line {match[0]}: {match[1]}") - logger.info(f'[End of matches for "{search_term}" in {file_path}]') + res_list.append(f"Line {match[0]}: {match[1]}") + res_list.append(f'[End of matches for "{search_term}" in {file_path}]') else: - logger.info(f'[No matches found for "{search_term}" in {file_path}]') + res_list.append(f'[No matches found for "{search_term}" in {file_path}]') + return "\n".join(res_list) - def find_file(self, file_name: str, dir_path: str = "./") -> None: + def find_file(self, file_name: str, dir_path: str = "./") -> str: """Finds all files with the given name in the specified directory. Args: file_name: str: The name of the file to find. dir_path: str: The path to the directory to search. """ - file_name = self.working_dir / Path(file_name) - dir_path = self.working_dir / Path(dir_path) + file_name = self._try_fix_path(file_name) + dir_path = self._try_fix_path(dir_path) if not dir_path.is_dir(): raise FileNotFoundError(f"Directory {dir_path} not found") @@ -792,10 +806,20 @@ class Editor(BaseModel): if file_name in file: matches.append(Path(root) / file) + res_list = [] if matches: - logger.info(f'[Found {len(matches)} matches for "{file_name}" in {dir_path}]') + res_list.append(f'[Found {len(matches)} matches for "{file_name}" in {dir_path}]') for match in matches: - logger.info(f"{match}") - logger.info(f'[End of matches for "{file_name}" in {dir_path}]') + res_list.append(f"{match}") + res_list.append(f'[End of matches for "{file_name}" in {dir_path}]') else: - logger.info(f'[No matches found for "{file_name}" in {dir_path}]') + res_list.append(f'[No matches found for "{file_name}" in {dir_path}]') + return "\n".join(res_list) + + def _try_fix_path(self, path: Union[Path, str]) -> Path: + """Tries to fix the path if it is not absolute.""" + if not isinstance(path, Path): + path = Path(path) + if not path.is_absolute(): + path = self.working_dir / path + return path diff --git a/metagpt/tools/libs/file_io_operator.py b/metagpt/tools/libs/file_io_operator.py index 3e846c333..f30d2d4fd 100644 --- a/metagpt/tools/libs/file_io_operator.py +++ b/metagpt/tools/libs/file_io_operator.py @@ -26,7 +26,7 @@ class LineNumberError(Exception): @register_tool() -class FileIOOperator(BaseModel): +class FileOperator(BaseModel): """ A state-of-state tool for reading, understanding, and writing files. """ @@ -82,7 +82,7 @@ class FileIOOperator(BaseModel): @staticmethod async def _read_pdf(path: Union[str, Path]) -> List[str]: - result = await FileIOOperator._omniparse_read_file(path) + result = await FileOperator._omniparse_read_file(path) if result: return result @@ -94,7 +94,7 @@ class FileIOOperator(BaseModel): @staticmethod async def _read_docx(path: Union[str, Path]) -> List[str]: - result = await FileIOOperator._omniparse_read_file(path) + result = await FileOperator._omniparse_read_file(path) if result: return result return read_docx(str(path)) @@ -106,7 +106,7 @@ class FileIOOperator(BaseModel): base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") if not base_url: - base_url = await FileIOOperator._read_omniparse_config() + base_url = await FileOperator._read_omniparse_config() if not base_url: return None api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") diff --git a/metagpt/tools/libs/linter.py b/metagpt/tools/libs/linter.py index d77384095..9f3ab7fd0 100644 --- a/metagpt/tools/libs/linter.py +++ b/metagpt/tools/libs/linter.py @@ -46,7 +46,6 @@ class Linter: def run_cmd(self, cmd, rel_fname, code): cmd += " " + rel_fname cmd = cmd.split() - process = subprocess.Popen(cmd, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) stdout, _ = process.communicate() errors = stdout.decode().strip() diff --git a/tests/data/tools/test_script_for_editor.py b/tests/data/tools/test_script_for_editor.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 6f0861c75..18b1400bf 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -1,10 +1,7 @@ -import contextlib -import io - import pytest from metagpt.const import TEST_DATA_PATH -from metagpt.tools.libs.editor import WINDOW, Editor +from metagpt.tools.libs.editor import Editor TEST_FILE_CONTENT = """ # this is line one @@ -17,6 +14,7 @@ def test_function_for_fm(): """.strip() TEST_FILE_PATH = TEST_DATA_PATH / "tools/test_script_for_editor.py" +WINDOW = 100 @pytest.fixture @@ -38,32 +36,43 @@ def test_function_for_fm(): """.strip() -@pytest.mark.skip def test_replace_content(test_file): - Editor().write_content( - file_path=str(TEST_FILE_PATH), - start_line=3, - end_line=5, - new_block_content=" # This is the new line A replacing lines 3 to 5.\n # This is the new line B.", + editor = Editor() + editor._edit_file_impl( + file_name=TEST_FILE_PATH, + start=3, + end=5, + content=" # This is the new line A replacing lines 3 to 5.\n # This is the new line B.", + is_insert=False, + is_append=False, ) with open(TEST_FILE_PATH, "r") as f: new_content = f.read() - assert new_content == EXPECTED_CONTENT_AFTER_REPLACE + assert new_content.strip() == EXPECTED_CONTENT_AFTER_REPLACE.strip() EXPECTED_CONTENT_AFTER_DELETE = """ # this is line one def test_function_for_fm(): + c = 3 # this is the 7th line """.strip() def test_delete_content(test_file): - Editor().write_content(file_path=str(TEST_FILE_PATH), start_line=3, end_line=5) + editor = Editor() + editor._edit_file_impl( + file_name=TEST_FILE_PATH, + start=3, + end=5, + content="", + is_insert=False, + is_append=False, + ) with open(TEST_FILE_PATH, "r") as f: new_content = f.read() - assert new_content == EXPECTED_CONTENT_AFTER_DELETE + assert new_content.strip() == EXPECTED_CONTENT_AFTER_DELETE.strip() EXPECTED_CONTENT_AFTER_INSERT = """ @@ -78,17 +87,19 @@ def test_function_for_fm(): """.strip() -@pytest.mark.skip def test_insert_content(test_file): - Editor().write_content( - file_path=str(TEST_FILE_PATH), - start_line=3, - end_line=-1, - new_block_content=" # This is the new line to be inserted, at line 3", + editor = Editor() + editor._edit_file_impl( + file_name=TEST_FILE_PATH, + start=3, + end=3, + content=" # This is the new line to be inserted, at line 3", + is_insert=True, + is_append=False, ) with open(TEST_FILE_PATH, "r") as f: new_content = f.read() - assert new_content == EXPECTED_CONTENT_AFTER_INSERT + assert new_content.strip() == EXPECTED_CONTENT_AFTER_INSERT.strip() @pytest.mark.parametrize( @@ -117,12 +128,6 @@ async def test_read_files(filename): assert file_block.block_content -@pytest.fixture(autouse=True) -def reset_current_file(): - global CURRENT_FILE - CURRENT_FILE = None - - def _numbered_test_lines(start, end) -> str: return ("\n".join(f"{i}|" for i in range(start, end + 1))) + "\n" @@ -150,24 +155,20 @@ def _calculate_window_bounds(current_line, total_lines, window_size): return start, end -@pytest.mark.asyncio -async def test_open_file_unexist_path(): +def test_open_file_unexist_path(): editor = Editor() with pytest.raises(FileNotFoundError): editor.open_file("/unexist/path/a.txt") -@pytest.mark.asyncio -async def test_open_file(tmp_path): +def test_open_file(tmp_path): editor = Editor() assert tmp_path is not None temp_file_path = tmp_path / "a.txt" temp_file_path.write_text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5") - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path)) + assert result is not None expected = ( f"[File: {temp_file_path} (5 lines total)]\n" @@ -177,21 +178,17 @@ async def test_open_file(tmp_path): "3|Line 3\n" "4|Line 4\n" "5|Line 5\n" - "(this is the end of the file)\n" + "(this is the end of the file)" ) assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_open_file_with_indentation(tmp_path): +def test_open_file_with_indentation(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" temp_file_path.write_text("Line 1\n Line 2\nLine 3\nLine 4\nLine 5") - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path)) assert result is not None expected = ( f"[File: {temp_file_path} (5 lines total)]\n" @@ -201,33 +198,28 @@ async def test_open_file_with_indentation(tmp_path): "3|Line 3\n" "4|Line 4\n" "5|Line 5\n" - "(this is the end of the file)\n" + "(this is the end of the file)" ) assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_open_file_long(tmp_path): +def test_open_file_long(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 1001)]) temp_file_path.write_text(content) - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path), 1, 50) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path), 1, 50) assert result is not None expected = f"[File: {temp_file_path} (1000 lines total)]\n" expected += "(this is the beginning of the file)\n" for i in range(1, 51): expected += f"{i}|Line {i}\n" - expected += "(950 more lines below)\n" + expected += "(950 more lines below)" assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_open_file_long_with_lineno(tmp_path): +def test_open_file_long_with_lineno(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 1001)]) @@ -235,10 +227,7 @@ async def test_open_file_long_with_lineno(tmp_path): cur_line = 100 - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path), cur_line) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path), cur_line) assert result is not None expected = f"[File: {temp_file_path} (1000 lines total)]\n" start, end = _calculate_window_bounds(cur_line, 1000, WINDOW) @@ -251,61 +240,44 @@ async def test_open_file_long_with_lineno(tmp_path): if end == 1000: expected += "(this is the end of the file)\n" else: - expected += f"({1000 - end} more lines below)\n" + expected += f"({1000 - end} more lines below)" assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_create_file_unexist_path(): +def test_create_file_unexist_path(): editor = Editor() with pytest.raises(FileNotFoundError): editor.create_file("/unexist/path/a.txt") -@pytest.mark.asyncio -async def test_create_file(tmp_path): +def test_create_file(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.create_file(str(temp_file_path)) - result = buf.getvalue() + result = editor.create_file(str(temp_file_path)) - expected = ( - f"[File: {temp_file_path} (1 lines total)]\n" - "(this is the beginning of the file)\n" - "1|\n" - "(this is the end of the file)\n" - f"[File {temp_file_path} created.]\n" - ) + expected = f"[File {temp_file_path} created.]" assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_goto_line(tmp_path): +def test_goto_line(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" total_lines = 1000 content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) temp_file_path.write_text(content) - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path)) assert result is not None expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" expected += "(this is the beginning of the file)\n" for i in range(1, WINDOW + 1): expected += f"{i}|Line {i}\n" - expected += f"({total_lines - WINDOW} more lines below)\n" + expected += f"({total_lines - WINDOW} more lines below)" assert result.split("\n") == expected.split("\n") - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.goto_line(500) - result = buf.getvalue() + result = editor.goto_line(500) + assert result is not None cur_line = 500 @@ -320,50 +292,39 @@ async def test_goto_line(tmp_path): if end == total_lines: expected += "(this is the end of the file)\n" else: - expected += f"({total_lines - end} more lines below)\n" + expected += f"({total_lines - end} more lines below)" assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_goto_line_negative(tmp_path): +def test_goto_line_negative(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 5)]) temp_file_path.write_text(content) - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) + editor.open_file(str(temp_file_path)) with pytest.raises(ValueError): editor.goto_line(-1) -@pytest.mark.asyncio -async def test_goto_line_out_of_bound(tmp_path): +def test_goto_line_out_of_bound(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 5)]) temp_file_path.write_text(content) - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) + editor.open_file(str(temp_file_path)) with pytest.raises(ValueError): editor.goto_line(100) -@pytest.mark.asyncio -async def test_scroll_down(tmp_path): +def test_scroll_down(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" total_lines = 1000 content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) temp_file_path.write_text(content) - - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path)) assert result is not None expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" @@ -375,15 +336,13 @@ async def test_scroll_down(tmp_path): for i in range(start, end + 1): expected += f"{i}|Line {i}\n" if end == total_lines: - expected += "(this is the end of the file)\n" + expected += "(this is the end of the file)" else: - expected += f"({total_lines - end} more lines below)\n" + expected += f"({total_lines - end} more lines below)" assert result.split("\n") == expected.split("\n") - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.scroll_down() - result = buf.getvalue() + result = editor.scroll_down() + assert result is not None expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" @@ -397,12 +356,11 @@ async def test_scroll_down(tmp_path): if end == total_lines: expected += "(this is the end of the file)\n" else: - expected += f"({total_lines - end} more lines below)\n" + expected += f"({total_lines - end} more lines below)" assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_scroll_up(tmp_path): +def test_scroll_up(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" total_lines = 1000 @@ -410,10 +368,8 @@ async def test_scroll_up(tmp_path): temp_file_path.write_text(content) cur_line = 300 - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path), cur_line) - result = buf.getvalue() + + result = editor.open_file(str(temp_file_path), cur_line) assert result is not None expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" @@ -427,13 +383,9 @@ async def test_scroll_up(tmp_path): if end == total_lines: expected += "(this is the end of the file)\n" else: - expected += f"({total_lines - end} more lines below)\n" + expected += f"({total_lines - end} more lines below)" assert result.split("\n") == expected.split("\n") - - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.scroll_up() - result = buf.getvalue() + result = editor.scroll_up() assert result is not None cur_line = cur_line - WINDOW @@ -449,44 +401,35 @@ async def test_scroll_up(tmp_path): if end == total_lines: expected += "(this is the end of the file)\n" else: - expected += f"({total_lines - end} more lines below)\n" + expected += f"({total_lines - end} more lines below)" assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_scroll_down_edge(tmp_path): +def test_scroll_down_edge(tmp_path): editor = Editor() temp_file_path = tmp_path / "a.txt" content = "\n".join([f"Line {i}" for i in range(1, 10)]) temp_file_path.write_text(content) - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(temp_file_path)) - result = buf.getvalue() + result = editor.open_file(str(temp_file_path)) assert result is not None expected = f"[File: {temp_file_path} (9 lines total)]\n" expected += "(this is the beginning of the file)\n" for i in range(1, 10): expected += f"{i}|Line {i}\n" - expected += "(this is the end of the file)\n" + expected += "(this is the end of the file)" - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.scroll_down() - result = buf.getvalue() + result = editor.scroll_down() assert result is not None assert result.split("\n") == expected.split("\n") -@pytest.mark.asyncio -async def test_print_window_internal(tmp_path): +def test_print_window_internal(tmp_path): editor = Editor() test_file_path = tmp_path / "a.txt" - await editor.create_file(str(test_file_path)) - editor.open_file(str(test_file_path)) + editor.create_file(str(test_file_path)) with open(test_file_path, "w") as file: for i in range(1, 101): file.write(f"Line `{i}`\n") @@ -494,20 +437,15 @@ async def test_print_window_internal(tmp_path): current_line = 50 window = 2 - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor._print_window(str(test_file_path), current_line, window, return_str=False) - result = buf.getvalue() - expected = "(48 more lines above)\n" "49|Line `49`\n" "50|Line `50`\n" "51|Line `51`\n" "(49 more lines below)\n" + result = editor._print_window(test_file_path, current_line, window) + expected = "(48 more lines above)\n" "49|Line `49`\n" "50|Line `50`\n" "51|Line `51`\n" "(49 more lines below)" assert result == expected -@pytest.mark.asyncio -async def test_open_file_large_line_number(tmp_path): +def test_open_file_large_line_number(tmp_path): editor = Editor() test_file_path = tmp_path / "a.txt" editor.create_file(str(test_file_path)) - editor.open_file(str(test_file_path)) with open(test_file_path, "w") as file: for i in range(1, 1000): file.write(f"Line `{i}`\n") @@ -515,24 +453,20 @@ async def test_open_file_large_line_number(tmp_path): current_line = 800 window = 100 - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(test_file_path), current_line, window) - result = buf.getvalue() + result = editor.open_file(str(test_file_path), current_line, window) + expected = f"[File: {test_file_path} (999 lines total)]\n" expected += "(749 more lines above)\n" for i in range(750, 850 + 1): expected += f"{i}|Line `{i}`\n" - expected += "(149 more lines below)\n" + expected += "(149 more lines below)" assert result == expected -@pytest.mark.asyncio -async def test_open_file_large_line_number_consecutive_diff_window(tmp_path): +def test_open_file_large_line_number_consecutive_diff_window(tmp_path): editor = Editor() test_file_path = tmp_path / "a.txt" editor.create_file(str(test_file_path)) - editor.open_file(str(test_file_path)) total_lines = 1000 with open(test_file_path, "w") as file: for i in range(1, total_lines + 1): @@ -541,10 +475,8 @@ async def test_open_file_large_line_number_consecutive_diff_window(tmp_path): current_line = 800 cur_window = 300 - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.open_file(str(test_file_path), current_line, cur_window) - result = buf.getvalue() + result = editor.open_file(str(test_file_path), current_line, cur_window) + expected = f"[File: {test_file_path} ({total_lines} lines total)]\n" start, end = _calculate_window_bounds(current_line, total_lines, cur_window) if start == 1: @@ -556,13 +488,26 @@ async def test_open_file_large_line_number_consecutive_diff_window(tmp_path): if end == total_lines: expected += "(this is the end of the file)\n" else: - expected += f"({total_lines - end} more lines below)\n" + expected += f"({total_lines - end} more lines below)" assert result == expected current_line = current_line - WINDOW - with io.StringIO() as buf: - with contextlib.redirect_stdout(buf): - editor.scroll_up() + + result = editor.scroll_up() + + expected = f"[File: {test_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(current_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line `{i}`\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + assert result.split("\n") == expected.split("\n") if __name__ == "__main__":