From 8e7696b8e6027404b399cb27ec1aae39b34bf10d Mon Sep 17 00:00:00 2001 From: liushaojie Date: Mon, 26 Aug 2024 13:40:11 +0800 Subject: [PATCH] update: editor --- metagpt/prompts/di/swe_agent.py | 4 +- metagpt/roles/di/role_zero.py | 23 +- metagpt/roles/di/swe_agent.py | 3 +- metagpt/tools/libs/editor.py | 1004 ++++++++++++++++++----- metagpt/tools/libs/linter.py | 222 +++++ requirements.txt | 2 + tests/metagpt/tools/libs/test_editor.py | 490 ++++++++++- 7 files changed, 1516 insertions(+), 232 deletions(-) create mode 100644 metagpt/tools/libs/linter.py diff --git a/metagpt/prompts/di/swe_agent.py b/metagpt/prompts/di/swe_agent.py index 86a062214..b543c01d5 100644 --- a/metagpt/prompts/di/swe_agent.py +++ b/metagpt/prompts/di/swe_agent.py @@ -183,7 +183,9 @@ IMPORTANT_TIPS = """ 15. When the edit fails, try to enlarge the starting line. -16. Once again, and this is critical: YOU CAN ONLY ENTER ONE COMMAND AT A TIME. +16. Use an absolute path instead of a relative path. + +17. Once again, and this is critical: YOU CAN ONLY ENTER ONE COMMAND AT A TIME. 
""" NEXT_STEP_TEMPLATE = f""" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index ab56dfa59..e32292b96 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -109,9 +109,6 @@ class RoleZero(Role): "Plan.append_task": self.planner.plan.append_task, "Plan.reset_task": self.planner.plan.reset_task, "Plan.replace_task": self.planner.plan.replace_task, - "Editor.write": self.editor.write, - "Editor.write_content": self.editor.write_content, - "Editor.read": self.editor.read, "RoleZero.ask_human": self.ask_human, "RoleZero.reply_to_human": self.reply_to_human, } @@ -132,6 +129,26 @@ class RoleZero(Role): ] } ) + self.tool_execution_map.update( + { + f"Editor.{i}": getattr(self.editor, i) + for i in [ + "append_file", + "create_file", + "edit_file_by_replace", + "find_file", + "goto_line", + "insert_content_at_line", + "open_file", + # "read", + "scroll_down", + "scroll_up", + "search_dir", + "search_file", + # "write", + ] + } + ) # can be updated by subclass self._update_tool_execution() return self diff --git a/metagpt/roles/di/swe_agent.py b/metagpt/roles/di/swe_agent.py index e1d2c9613..9efe9ce34 100644 --- a/metagpt/roles/di/swe_agent.py +++ b/metagpt/roles/di/swe_agent.py @@ -19,10 +19,11 @@ class SWEAgent(RoleZero): goal: str = "Resolve GitHub issue or bug in any existing codebase" _instruction: str = NEXT_STEP_TEMPLATE tools: list[str] = [ - "Bash", + # "Bash", "Browser:goto,scroll", "RoleZero", "git_create_pull", + "Editor", ] terminal: Bash = Field(default_factory=Bash, exclude=True) output_diff: str = "" diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 240c28767..dde5df613 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -1,7 +1,8 @@ import base64 import os +import re import shutil -import subprocess +import tempfile from pathlib import Path from typing import List, Optional, Union @@ -9,12 +10,16 @@ from pydantic import BaseModel, ConfigDict 
from metagpt.config2 import Config from metagpt.logs import logger +from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool from metagpt.utils import read_docx from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint -from metagpt.utils.repo_to_markdown import is_text_file from metagpt.utils.report import EditorReporter +# This is also used in unit tests! +MSG_FILE_UPDATED = "[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]" +LINTER_ERROR_MSG = "[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n" + class FileBlock(BaseModel): """A block of content in a file""" @@ -23,203 +28,65 @@ class FileBlock(BaseModel): block_content: str +class LineNumberError(Exception): + pass + + @register_tool() class Editor(BaseModel): """ - A tool for reading, understanding, writing, and editing files. - Support local file including text-based files (txt, md, json, py, html, js, css, etc.), pdf, docx, excluding images, csv, excel, or online links + A state-of-state tool for reading, understanding, writing, and editing files. + All path parameters should use an absolute path. """ model_config = ConfigDict(arbitrary_types_allowed=True) resource: EditorReporter = EditorReporter() + # CURRENT_FILE: Optional[str] = None + current_file: Optional[str] = None + current_line: int = 1 + # WINDOW: int = 100 + window: int = 100 - def write(self, path: str, content: str): - """Write the whole content to a file. 
When used, make sure content arg contains the full content of the file.""" - if "\n" not in content and "\\n" in content: - # A very raw rule to correct the content: If 'content' lacks actual newlines ('\n') but includes '\\n', consider - # replacing them with '\n' to potentially correct mistaken representations of newline characters. - content = content.replace("\\n", "\n") - directory = os.path.dirname(path) - if directory and not os.path.exists(directory): - os.makedirs(directory) - with open(path, "w", encoding="utf-8") as f: - f.write(content) - # self.resource.report(path, "path") - return f"The writing/coding the of the file {os.path.basename(path)}' is now completed. The file '{os.path.basename(path)}' has been successfully created." + # def write(self, path: str, content: str): + # """Write the whole content to a file. When used, make sure content arg contains the full content of the file.""" + # if "\n" not in content and "\\n" in content: + # # A very raw rule to correct the content: If 'content' lacks actual newlines ('\n') but includes '\\n', consider + # # replacing them with '\n' to potentially correct mistaken representations of newline characters. + # content = content.replace("\\n", "\n") + # directory = os.path.dirname(path) + # if directory and not os.path.exists(directory): + # os.makedirs(directory) + # with open(path, "w", encoding="utf-8") as f: + # f.write(content) + # # self.resource.report(path, "path") + # return f"The writing/coding the of the file {os.path.basename(path)}' is now completed. The file '{os.path.basename(path)}' has been successfully created." - async def read(self, path: str) -> FileBlock: - """Read the whole content of a file. 
Using absolute paths as the argument for specifying the file location.""" - is_text, mime_type = await is_text_file(path) - if is_text: - lines = await self._read_text(path) - elif mime_type == "application/pdf": - lines = await self._read_pdf(path) - elif mime_type in { - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-word.document.macroEnabled.12", - "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - "application/vnd.ms-word.template.macroEnabled.12", - }: - lines = await self._read_docx(path) - else: - return FileBlock(file_path=str(path), block_content="") - self.resource.report(str(path), "path") - - lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] - result = FileBlock( - file_path=str(path), - block_content="".join(lines_with_num), - ) - return result - - def search_content(self, symbol: str, root_path: str = ".", window: int = 50) -> FileBlock: - """ - Search symbol in all files under root_path, return the context of symbol with window size - Useful for locating class or function in a large codebase. Example symbol can be "def some_function", "class SomeClass", etc. - In searching, attempt different symbols of different granualities, e.g. "def some_function", "class SomeClass", a certain line of code, etc. - - Args: - symbol (str): The symbol to search. - root_path (str, optional): The root path to search in, the path can be a folder or a file. If not provided, search in the current directory. Defaults to ".". - window (int, optional): The window size to return. Defaults to 20. - - Returns: - FileBlock: The block containing the symbol, a pydantic BaseModel with the schema below. - class FileBlock(BaseModel): - file_path: str - block_content: str - """ - if not os.path.exists(root_path): - print(f"Currently at {os.getcwd()} containing: {os.listdir()}. 
Path {root_path} does not exist.") - return None - not_found_msg = ( - "symbol not found, you may try searching another one, or break down your search term to search a part of it" - ) - if os.path.isfile(root_path): - result = self._search_content_in_file(symbol, root_path, window) - if not result: - print(not_found_msg) - return result - for root, _, files in os.walk(root_path or "."): - for file in files: - file_path = os.path.join(root, file) - result = self._search_content_in_file(symbol, file_path, window) - if result: - # FIXME: This returns the first found result, not all results. - return result - print(not_found_msg) - return None - - def _search_content_in_file(self, symbol: str, file_path: str, window: int = 50) -> FileBlock: - print("search in", file_path) - if not file_path.endswith(".py"): - return None - with open(file_path, "r", encoding="utf-8") as f: - try: - lines = f.readlines() - except Exception: - return None - for i, line in enumerate(lines): - if symbol in line: - start = max(i - window, 0) - end = min(i + window, len(lines) - 1) - for row_num in range(start, end + 1): - lines[row_num] = f"{(row_num + 1):03}|{lines[row_num]}" - block_content = "".join(lines[start : end + 1]) - result = FileBlock( - file_path=file_path, - block_content=block_content, - ) - self.resource.report(result.file_path, "path", extra={"type": "search", "line": i, "symbol": symbol}) - return result - return None - - def write_content(self, file_path: str, start_line: int, end_line: int, new_block_content: str = "") -> str: - """ - Write a new block of content into a file. Use this method to update a block of code in a file. There are three cases: - 1. If the new block content is empty, the original block will be deleted. - 2. If the new block content is not empty and end_line < start_line (e.g. set end_line = -1) the new block content will be inserted at start_line. - 3. 
If the new block content is not empty and end_line >= start_line, the original block from start_line to end_line (both inclusively) will be replaced by the new block content. - This function can sometimes be used given a FileBlock upstream. You should carefully review its row number. Determine the start_line and end_line based on the row number of the FileBlock. - The file content from start_line to end_line will be replaced by your new_block_content. DON'T replace more than you intend to. - - Args: - file_path (str): The file path to write the new block content. - start_line (int): start line of the original block to be updated (inclusive). - end_line (int): end line of the original block to be updated (inclusive). - new_block_content (str): The new block content to write. Don't include row number in the content. - - Returns: - str: A message indicating the status of the write operation. - """ - # Create a temporary copy of the file - temp_file_path = file_path + ".temp" - shutil.copy(file_path, temp_file_path) - - try: - # Modify the temporary file with the new content - self._write_content(temp_file_path, start_line, end_line, new_block_content) - - # Lint the modified temporary file - lint_passed, lint_message = self._lint_file(temp_file_path) - # if not lint_passed: - # return f"Linting the content at a temp file, failed with:\n{lint_message}" - - # If linting passes, overwrite the original file with the temporary file - shutil.move(temp_file_path, file_path) - - new_file_block = FileBlock( - file_path=file_path, - block_content=new_block_content, - ) - self.resource.report(new_file_block.file_path, "path") - - return f"Content written successfully to {file_path}" - - finally: - # Clean up: Ensure the temporary file is removed if it still exists - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - - def _write_content(self, file_path: str, start_line: int, end_line: int, new_block_content: str = ""): - """start_line and end_line are both 1-based 
indices and inclusive.""" - with open(file_path, "r") as file: - lines = file.readlines() - - start_line_index = start_line - 1 # Adjusting because list indices start at 0 - end_line_index = end_line - - if new_block_content: - # Split the new_block_content by newline and ensure each line ends with a newline character - new_content_lines = new_block_content.splitlines( - keepends=True - ) # FIXME: This will split \n within a line, such as ab\ncd - if end_line >= start_line: - # This replaces the block between start_line and end_line with new_block_content - # irrespective of the length difference between the original and new content. - lines[start_line_index:end_line_index] = new_content_lines - else: - lines.insert(start_line_index, "".join(new_content_lines)) - else: - del lines[start_line_index:end_line_index] - - with open(file_path, "w") as file: - file.writelines(lines) - - @classmethod - def _lint_file(cls, file_path: str) -> (bool, str): - """Lints an entire Python file using pylint, returns True if linting passes, along with pylint's output.""" - result = subprocess.run( - ["pylint", file_path, "--disable=all", "--enable=E"], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - ) - lint_passed = result.returncode == 0 - lint_message = result.stdout - return lint_passed, lint_message + # async def read(self, path: str) -> FileBlock: + # """Read the whole content of a file. 
Using absolute paths as the argument for specifying the file location.""" + # is_text, mime_type = await is_text_file(path) + # if is_text: + # lines = await self._read_text(path) + # elif mime_type == "application/pdf": + # lines = await self._read_pdf(path) + # elif mime_type in { + # "application/msword", + # "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + # "application/vnd.ms-word.document.macroEnabled.12", + # "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + # "application/vnd.ms-word.template.macroEnabled.12", + # }: + # lines = await self._read_docx(path) + # else: + # return FileBlock(file_path=str(path), block_content="") + # self.resource.report(str(path), "path") + # + # lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] + # result = FileBlock( + # file_path=str(path), + # block_content="".join(lines_with_num), + # ) + # return result @staticmethod async def _read_text(path: Union[str, Path]) -> List[str]: @@ -292,3 +159,762 @@ class Editor(BaseModel): if config.omniparse and config.omniparse.url: return config.omniparse.url return "" + + @staticmethod + def _is_valid_filename(file_name) -> bool: + if not file_name or not isinstance(file_name, str) or not file_name.strip(): + return False + invalid_chars = '<>:"/\\|?*' + if os.name == "nt": # Windows + invalid_chars = '<>:"/\\|?*' + elif os.name == "posix": # Unix-like systems + invalid_chars = "\0" + + for char in invalid_chars: + if char in file_name: + return False + return True + + @staticmethod + def _is_valid_path(path) -> bool: + if not path or not isinstance(path, str): + return False + try: + return os.path.exists(os.path.normpath(path)) + except PermissionError: + return False + + @staticmethod + def _create_paths(file_name) -> bool: + try: + dirname = os.path.dirname(file_name) + if dirname: + os.makedirs(dirname, exist_ok=True) + return True + except PermissionError: + return False + + def _check_current_file(self, 
file_path: Optional[str] = None) -> bool: + if not file_path: + file_path = self.current_file + if not file_path or not os.path.isfile(file_path): + raise ValueError("No file open. Use the open_file function first.") + return True + + @staticmethod + def _clamp(value, min_value, max_value): + return max(min_value, min(value, max_value)) + + @staticmethod + def _lint_file(file_path: str) -> tuple[Optional[str], Optional[int]]: + """Lint the file at the given path and return a tuple with a boolean indicating if there are errors, + and the line number of the first error, if any. + + Returns: + tuple[str | None, int | None]: (lint_error, first_error_line_number) + """ + linter = Linter(root=os.getcwd()) + lint_error = linter.lint(file_path) + if not lint_error: + # Linting successful. No issues found. + return None, None + return "ERRORS:\n" + lint_error.text, lint_error.lines[0] + + def _print_window(self, file_path, targeted_line, window, return_str=False): + self._check_current_file(file_path) + with open(file_path) as file: + content = file.read() + + # Ensure the content ends with a newline character + if not content.endswith("\n"): + content += "\n" + + lines = content.splitlines(True) # Keep all line ending characters + total_lines = len(lines) + + # cover edge cases + self.current_line = self._clamp(targeted_line, 1, total_lines) + half_window = max(1, window // 2) + + # Ensure at least one line above and below the targeted line + start = max(1, self.current_line - half_window) + end = min(total_lines, self.current_line + half_window) + + # Adjust start and end to ensure at least one line above and below + if start == 1: + end = min(total_lines, start + window - 1) + if end == total_lines: + start = max(1, end - window + 1) + + output = "" + + # only display this when there's at least one line above + if start > 1: + output += f"({start - 1} more lines above)\n" + else: + output += "(this is the beginning of the file)\n" + for i in range(start, end + 1): + 
_new_line = f"{i}|{lines[i - 1]}" + if not _new_line.endswith("\n"): + _new_line += "\n" + output += _new_line + if end < total_lines: + output += f"({total_lines - end} more lines below)\n" + else: + output += "(this is the end of the file)\n" + output = output.rstrip() + + if return_str: + return output + else: + print(output) + + @staticmethod + def _cur_file_header(current_file, total_lines) -> str: + if not current_file: + return "" + return f"[File: {os.path.abspath(current_file)} ({total_lines} lines total)]\n" + + def open_file(self, path: str, line_number: Optional[int] = 1, context_lines: Optional[int] = None) -> None: + """Opens the file at the given path in the editor. If line_number is provided, the window will be moved to include that line. + It only shows the first 100 lines by default! Max `context_lines` supported is 2000, use `scroll up/down` + to view the file if you want to see more. + + Args: + path: str: The path to the file to open, preferred absolute path. + line_number: int | None = 1: The line number to move to. Defaults to 1. + context_lines: int | None = 100: Only shows this number of lines in the context window (usually from line 1), with line_number as the center (if possible). Defaults to 100. 
+ """ + if context_lines is None: + context_lines = self.window + + if not os.path.isfile(path): + raise FileNotFoundError(f"File {path} not found") + + CURRENT_FILE = os.path.abspath(path) + with open(CURRENT_FILE) as file: + total_lines = max(1, sum(1 for _ in file)) + + if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines: + raise ValueError(f"Line number must be between 1 and {total_lines}") + self.current_line = line_number + + # Override WINDOW with context_lines + if context_lines is None or context_lines < 1: + context_lines = self.window + + output = self._cur_file_header(CURRENT_FILE, total_lines) + output += self._print_window( + CURRENT_FILE, self.current_line, self._clamp(context_lines, 1, 2000), return_str=True + ) + print(output) + + def goto_line(self, line_number: int) -> None: + """Moves the window to show the specified line number. + + Args: + line_number: int: The line number to move to. + """ + self._check_current_file() + + with open(str(self.current_file)) as file: + total_lines = max(1, sum(1 for _ in file)) + if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines: + raise ValueError(f"Line number must be between 1 and {total_lines}") + + self.current_line = self._clamp(line_number, 1, total_lines) + + output = self._cur_file_header(self.current_file, total_lines) + output += self._print_window(self.current_file, self.current_line, self.window, return_str=True) + print(output) + + def scroll_down(self) -> None: + """Moves the window down by 100 lines.""" + self._check_current_file() + + with open(str(self.current_file)) as file: + total_lines = max(1, sum(1 for _ in file)) + self.current_line = self._clamp(self.current_line + self.window, 1, total_lines) + output = self._cur_file_header(self.current_file, total_lines) + output += self._print_window(self.current_file, self.current_line, self.window, return_str=True) + print(output) + + def scroll_up(self) -> None: + """Moves the 
window up by 100 lines.""" + self._check_current_file() + + with open(str(self.current_file)) as file: + total_lines = max(1, sum(1 for _ in file)) + self.current_line = self._clamp(self.current_line - self.window, 1, total_lines) + output = self._cur_file_header(self.current_file, total_lines) + output += self._print_window(self.current_file, self.current_line, self.window, return_str=True) + print(output) + + @classmethod + def create_file(cls, filename: str) -> None: + """Creates and opens a new file with the given name. + + Args: + filename: str: The name of the file to create. + """ + if os.path.exists(filename): + raise FileExistsError(f"File '{filename}' already exists.") + + with open(filename, "w") as file: + file.write("\n") + + cls.open_file(filename) + print(f"[File {filename} created.]") + + @staticmethod + def _append_impl(lines, content): + """Internal method to handle appending to a file. + + Args: + lines: list[str]: The lines in the original file. + content: str: The content to append to the file. + + Returns: + content: str: The new content of the file. + n_added_lines: int: The number of lines added to the file. + """ + content_lines = content.splitlines(keepends=True) + n_added_lines = len(content_lines) + if lines and not (len(lines) == 1 and lines[0].strip() == ""): + # file is not empty + if not lines[-1].endswith("\n"): + lines[-1] += "\n" + new_lines = lines + content_lines + content = "".join(new_lines) + else: + # file is empty + content = "".join(content_lines) + + return content, n_added_lines + + @staticmethod + def _insert_impl(lines, start, content): + """Internal method to handle inserting to a file. + + Args: + lines: list[str]: The lines in the original file. + start: int: The start line number for inserting. + content: str: The content to insert to the file. + + Returns: + content: str: The new content of the file. + n_added_lines: int: The number of lines added to the file. 
+ + Raises: + LineNumberError: If the start line number is invalid. + """ + inserted_lines = [content + "\n" if not content.endswith("\n") else content] + if len(lines) == 0: + new_lines = inserted_lines + elif start is not None: + if len(lines) == 1 and lines[0].strip() == "": + # if the file with only 1 line and that line is empty + lines = [] + + if len(lines) == 0: + new_lines = inserted_lines + else: + new_lines = lines[: start - 1] + inserted_lines + lines[start - 1 :] + else: + raise LineNumberError( + f"Invalid line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive)." + ) + + content = "".join(new_lines) + n_added_lines = len(inserted_lines) + return content, n_added_lines + + @staticmethod + def _edit_impl(lines, start, end, content): + """Internal method to handle editing a file. + + REQUIRES (should be checked by caller): + start <= end + start and end are between 1 and len(lines) (inclusive) + content ends with a newline + + Args: + lines: list[str]: The lines in the original file. + start: int: The start line number for editing. + end: int: The end line number for editing. + content: str: The content to replace the lines with. + + Returns: + content: str: The new content of the file. + n_added_lines: int: The number of lines added to the file. + """ + # Handle cases where start or end are None + if start is None: + start = 1 # Default to the beginning + if end is None: + end = len(lines) # Default to the end + # Check arguments + if not (1 <= start <= len(lines)): + raise LineNumberError( + f"Invalid start line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive)." + ) + if not (1 <= end <= len(lines)): + raise LineNumberError( + f"Invalid end line number: {end}. Line numbers must be between 1 and {len(lines)} (inclusive)." + ) + if start > end: + raise LineNumberError(f"Invalid line range: {start}-{end}. 
Start must be less than or equal to end.") + + if not content.endswith("\n"): + content += "\n" + content_lines = content.splitlines(True) + n_added_lines = len(content_lines) + new_lines = lines[: start - 1] + content_lines + lines[end:] + content = "".join(new_lines) + return content, n_added_lines + + def _edit_file_impl( + self, + file_name: str, + start: Optional[int] = None, + end: Optional[int] = None, + content: str = "", + is_insert: bool = False, + is_append: bool = False, + ) -> str: + """Internal method to handle common logic for edit_/append_file methods. + + Args: + file_name: str: The name of the file to edit or append to. + start: int | None = None: The start line number for editing. Ignored if is_append is True. + end: int | None = None: The end line number for editing. Ignored if is_append is True. + content: str: The content to replace the lines with or to append. + is_insert: bool = False: Whether to insert content at the given line number instead of editing. + is_append: bool = False: Whether to append content to the file instead of editing. + """ + ret_str = "" + + ERROR_MSG = f"[Error editing file {file_name}. Please confirm the file is correct.]" + ERROR_MSG_SUFFIX = ( + "Your changes have NOT been applied. Please fix your edit command and try again.\n" + "You either need to 1) Open the correct file and try again or 2) Specify the correct line number arguments.\n" + "DO NOT re-run the same failed edit command. Running it again will lead to the same error." 
+ ) + + if not self._is_valid_filename(file_name): + raise FileNotFoundError("Invalid file name.") + + if not self._is_valid_path(file_name): + raise FileNotFoundError("Invalid path or file name.") + + if not self._create_paths(file_name): + raise PermissionError("Could not access or create directories.") + + if not os.path.isfile(file_name): + raise FileNotFoundError(f"File {file_name} not found.") + + if is_insert and is_append: + raise ValueError("Cannot insert and append at the same time.") + + # Use a temporary file to write changes + content = str(content or "") + temp_file_path = "" + src_abs_path = os.path.abspath(file_name) + first_error_line = None + + try: + n_added_lines = None + + # lint the original file + enable_auto_lint = os.getenv("ENABLE_AUTO_LINT", "false").lower() == "true" + if enable_auto_lint: + original_lint_error, _ = self._lint_file(file_name) + + # Create a temporary file + with tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + temp_file_path = temp_file.name + + # Read the original file and check if empty and for a trailing newline + with open(file_name) as original_file: + lines = original_file.readlines() + + if is_append: + content, n_added_lines = self._append_impl(lines, content) + elif is_insert: + try: + content, n_added_lines = self._insert_impl(lines, start, content) + except LineNumberError as e: + ret_str += (f"{ERROR_MSG}\n" f"{e}\n" f"{ERROR_MSG_SUFFIX}") + "\n" + return ret_str + else: + try: + content, n_added_lines = self._edit_impl(lines, start, end, content) + except LineNumberError as e: + ret_str += (f"{ERROR_MSG}\n" f"{e}\n" f"{ERROR_MSG_SUFFIX}") + "\n" + return ret_str + + if not content.endswith("\n"): + content += "\n" + + # Write the new content to the temporary file + temp_file.write(content) + + # Replace the original file with the temporary file atomically + shutil.move(temp_file_path, src_abs_path) + + # Handle linting + # NOTE: we need to get env var inside this function + # because the env 
var will be set AFTER the agentskills is imported + if enable_auto_lint: + # BACKUP the original file + original_file_backup_path = os.path.join( + os.path.dirname(file_name), + f".backup.{os.path.basename(file_name)}", + ) + with open(original_file_backup_path, "w") as f: + f.writelines(lines) + + lint_error, first_error_line = self._lint_file(file_name) + + # Select the errors caused by the modification + def extract_last_part(line): + parts = line.split(":") + if len(parts) > 1: + return parts[-1].strip() + return line.strip() + + def subtract_strings(str1, str2) -> str: + lines1 = str1.splitlines() + lines2 = str2.splitlines() + + last_parts1 = [extract_last_part(line) for line in lines1] + + remaining_lines = [line for line in lines2 if extract_last_part(line) not in last_parts1] + + result = "\n".join(remaining_lines) + return result + + if original_lint_error and lint_error: + lint_error = subtract_strings(original_lint_error, lint_error) + if lint_error == "": + lint_error = None + first_error_line = None + + if lint_error is not None: + if first_error_line is not None: + show_line = int(first_error_line) + elif is_append: + # original end-of-file + show_line = len(lines) + # insert OR edit WILL provide meaningful line numbers + elif start is not None and end is not None: + show_line = int((start + end) / 2) + else: + raise ValueError("Invalid state. 
This should never happen.") + + ret_str += LINTER_ERROR_MSG + ret_str += lint_error + "\n" + + editor_lines = n_added_lines + 20 + + ret_str += "[This is how your edit would have looked if applied]\n" + ret_str += "-------------------------------------------------\n" + ret_str += self._print_window(file_name, show_line, editor_lines, return_str=True) + "\n" + ret_str += "-------------------------------------------------\n\n" + + ret_str += "[This is the original code before your edit]\n" + ret_str += "-------------------------------------------------\n" + ret_str += ( + self._print_window( + original_file_backup_path, + show_line, + editor_lines, + return_str=True, + ) + + "\n" + ) + ret_str += "-------------------------------------------------\n" + + ret_str += ( + "Your changes have NOT been applied. Please fix your edit command and try again.\n" + "You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n" + "DO NOT re-run the same failed edit command. Running it again will lead to the same error." 
def edit_file_by_replace(self, file_name: str, to_replace: str, new_content: str) -> None:
    """Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`.

    Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc.

    Include enough lines to make code in `to_replace` unique. `to_replace` should NOT be empty.

    For example, given a file "/workspace/example.txt" with the following content:
    ```
    line 1
    line 2
    line 2
    line 3
    ```

    EDITING: If you want to replace the second occurrence of "line 2", you can make `to_replace` unique:

    edit_file_by_replace(
        '/workspace/example.txt',
        to_replace='line 2\nline 3',
        new_content='new line\nline 3',
    )

    This will replace only the second "line 2" with "new line". The first "line 2" will remain unchanged.

    The resulting file will be:
    ```
    line 1
    line 2
    new line
    line 3
    ```

    REMOVAL: If you want to remove "line 2" and "line 3", you can set `new_content` to an empty string:

    edit_file_by_replace(
        '/workspace/example.txt',
        to_replace='line 2\nline 3',
        new_content='',
    )

    Args:
        file_name: str: The name of the file to edit.
        to_replace: str: The content to search for and replace.
        new_content: str: The new content to replace the old content with.

    Raises:
        ValueError: If `to_replace` is blank, equals `new_content`, or matches
            more than one location (exactly, or after ignoring spaces/tabs).
    """
    # FIXME: support replacing *all* occurrences
    if not to_replace.strip():
        raise ValueError("`to_replace` must not be empty.")

    if to_replace == new_content:
        raise ValueError("`to_replace` and `new_content` must be different.")

    not_unique_msg = (
        "`to_replace` appears more than once, please include enough lines to make code in `to_replace` unique."
    )

    with open(file_name, "r") as file:
        file_content = file.read()

    if file_content.count(to_replace) > 1:
        raise ValueError(not_unique_msg)

    # `searched` is the text the match offset refers to: the raw file content
    # for an exact hit, or its whitespace-stripped form for the fuzzy fallback.
    searched = file_content
    start = searched.find(to_replace)
    if start == -1:

        def _fuzzy_transform(s: str) -> str:
            # Remove every whitespace character except newlines, so that the
            # number of newlines (and therefore line numbers) is preserved.
            return re.sub(r"[^\S\n]+", "", s)

        needle = _fuzzy_transform(to_replace)
        searched = _fuzzy_transform(file_content)
        # The uniqueness guarantee must hold for the fuzzy match as well,
        # otherwise the wrong occurrence could be overwritten silently.
        if searched.count(needle) > 1:
            raise ValueError(not_unique_msg)
        start = searched.find(needle)
        if start == -1:
            print(f"[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]")
            return

    # Convert the character offset into 1-based, inclusive line numbers.
    start_line_number = searched[:start].count("\n") + 1
    end_line_number = start_line_number + len(to_replace.splitlines()) - 1

    # `_edit_file_impl` is an instance method, so it has to be reached through
    # `self`; calling it through the class bound `file_name` to `self`.
    ret_str = self._edit_file_impl(
        file_name,
        start=start_line_number,
        end=end_line_number,
        content=new_content,
        is_insert=False,
    )
    # lint_error = bool(LINTER_ERROR_MSG in ret_str)
    # TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation)
    print(ret_str)


def insert_content_at_line(self, file_name: str, line_number: int, content: str) -> None:
    """Insert content at the given line number in a file.
    This will NOT modify the content of the lines before OR after the given line number.

    For example, if the file has the following content:
    ```
    line 1
    line 2
    line 3
    ```
    and you call `insert_content_at_line('file.txt', 2, 'new line')`, the file will be updated to:
    ```
    line 1
    new line
    line 2
    line 3
    ```

    Args:
        file_name: str: The name of the file to edit.
        line_number: int: The line number (starting from 1) at which the content is inserted (see the example above).
        content: str: The content to insert.
    """
    # Delegates through `self`: `_edit_file_impl` is an instance method.
    ret_str = self._edit_file_impl(
        file_name,
        start=line_number,
        end=line_number,
        content=content,
        is_insert=True,
        is_append=False,
    )
    print(ret_str)


def append_file(self, file_name: str, content: str) -> None:
    """Append content to the given file.
    It appends text `content` to the end of the specified file.

    Args:
        file_name: str: The name of the file to edit.
        content: str: The content to append.
    """
    # Delegates through `self`: `_edit_file_impl` is an instance method.
    ret_str = self._edit_file_impl(
        file_name,
        start=None,
        end=None,
        content=content,
        is_insert=False,
        is_append=True,
    )
    print(ret_str)


@classmethod
def search_dir(cls, search_term: str, dir_path: str = "./") -> None:
    """Searches for search_term in all files in dir. If dir is not provided, searches in the current directory.

    Files whose name starts with "." are skipped, and so are files that
    cannot be opened or read.

    Args:
        search_term: str: The term to search for.
        dir_path: str: The path to the directory to search.

    Raises:
        FileNotFoundError: If `dir_path` is not an existing directory.
    """
    if not os.path.isdir(dir_path):
        raise FileNotFoundError(f"Directory {dir_path} not found")

    matches = []
    for root, _, files in os.walk(dir_path):
        for file in files:
            if file.startswith("."):
                continue
            file_path = os.path.join(root, file)
            try:
                with open(file_path, "r", errors="ignore") as f:
                    for line_num, line in enumerate(f, 1):
                        if search_term in line:
                            matches.append((file_path, line_num, line.strip()))
            except OSError:
                # One unreadable entry (permissions, broken symlink, ...) must
                # not abort the search of the remaining files.
                continue

    if not matches:
        print(f'No matches found for "{search_term}" in {dir_path}')
        return

    num_matches = len(matches)
    num_files = len({match[0] for match in matches})

    if num_files > 100:
        print(f'More than {num_files} files matched for "{search_term}" in {dir_path}. Please narrow your search.')
        return

    print(f'[Found {num_matches} matches for "{search_term}" in {dir_path}]')
    for file_path, line_num, line in matches:
        print(f"{file_path} (Line {line_num}): {line}")
    print(f'[End of matches for "{search_term}" in {dir_path}]')


def search_file(self, search_term: str, file_path: Optional[str] = None) -> None:
    """Searches for search_term in file. If file is not provided, searches in the current open file.

    Args:
        search_term: str: The term to search for.
        file_path: str | None: The path to the file to search.

    Raises:
        FileNotFoundError: If no file is given or open, or the file does not exist.
    """
    if file_path is None:
        file_path = self.current_file
    if file_path is None:
        raise FileNotFoundError("No file specified or open. Use the open_file function first.")
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File {file_path} not found")

    # errors="ignore" mirrors search_dir, so undecodable bytes do not abort the search.
    with open(file_path, errors="ignore") as file:
        matches = [(i, line.strip()) for i, line in enumerate(file, 1) if search_term in line]

    if matches:
        print(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]')
        for line_no, text in matches:
            print(f"Line {line_no}: {text}")
        print(f'[End of matches for "{search_term}" in {file_path}]')
    else:
        print(f'[No matches found for "{search_term}" in {file_path}]')
# tree_sitter is throwing a FutureWarning
warnings.simplefilter("ignore", category=FutureWarning)


@dataclass
class LintResult:
    """Outcome of a lint run that found at least one problem."""

    text: str  # raw linter output or formatted traceback
    lines: list  # line numbers reported for the problems


class Linter:
    """Run a language-specific lint over a file and report the first problems found."""

    def __init__(self, encoding="utf-8", root=None):
        """Store the text encoding and the directory lint commands are run from.

        Args:
            encoding: Encoding used to read source files and decode linter output.
            root: Working directory for external lint commands; also the base
                used by `get_rel_fname`. `None` means "leave paths untouched".
        """
        self.encoding = encoding
        self.root = root

        # Maps a language name to either a callable linter or a command string.
        self.languages = dict(
            python=self.py_lint,
        )
        self.all_lint_cmd = None

    def get_rel_fname(self, fname):
        """Return `fname` relative to `self.root`, or unchanged when no root is set."""
        if self.root:
            return os.path.relpath(fname, self.root)
        return fname

    def run_cmd(self, cmd, rel_fname, code):
        """Run the external lint command `cmd` on `rel_fname`.

        Args:
            cmd: Command line of the linter, without the file name.
            rel_fname: Path of the file to lint; passed as one single argument.
            code: Unused; kept so all lint back-ends share the same signature.

        Returns:
            `None` when the command exits with status 0, otherwise a
            `LintResult` holding the combined stdout/stderr text and a
            one-element list with the first line number found in it
            (`None` when the output carries no line number).
        """
        import shlex  # local import: not part of this module's original import block

        # Tokenise only the command; appending the path as its own argv entry
        # keeps file names containing spaces intact.
        argv = shlex.split(cmd) + [rel_fname]

        completed = subprocess.run(argv, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        self.returncode = completed.returncode
        if self.returncode == 0:
            return None  # zero exit status

        # errors="replace": a linter echoing undecodable bytes must not crash the lint.
        errors = completed.stdout.decode(self.encoding, errors="replace").strip()
        line_num = extract_error_line_from(errors)
        return LintResult(text=errors, lines=[line_num])

    def get_abs_fname(self, fname):
        """Return an absolute path for `fname` when it is absolute or an existing file.

        Any other path is returned via `get_rel_fname` (the original code
        describes this branch as the "temp file" case).
        """
        if os.path.isabs(fname):
            return fname
        if os.path.isfile(fname):
            return os.path.abspath(self.get_rel_fname(fname))
        return self.get_rel_fname(fname)  # if a temp file

    def lint(self, fname, cmd=None) -> Optional[LintResult]:
        """Lint `fname` and return a `LintResult`, or `None` when nothing was found.

        Args:
            fname: Path of the file to lint.
            cmd: Optional explicit lint command. When empty, the linter is
                chosen from the file's language (`all_lint_cmd` first, then
                the `languages` table).

        Returns:
            The result of the selected linter; `None` also when the language
            of the file cannot be determined.
        """
        code = Path(fname).read_text(self.encoding)
        absolute_fname = self.get_abs_fname(fname)

        if cmd:
            cmd = cmd.strip()
        if not cmd:
            lang = filename_to_lang(fname)
            if not lang:
                return None
            cmd = self.all_lint_cmd if self.all_lint_cmd else self.languages.get(lang)

        if callable(cmd):
            return cmd(fname, absolute_fname, code)
        if cmd:
            return self.run_cmd(cmd, absolute_fname, code)
        return basic_lint(absolute_fname, code)

    def flake_lint(self, rel_fname, code):
        """Run flake8 restricted to fatal error codes; return `None` if flake8 is not installed."""
        fatal = "F821,F822,F831,E112,E113,E999,E902"
        flake8 = f"flake8 --select={fatal} --isolated"

        try:
            return self.run_cmd(flake8, rel_fname, code)
        except FileNotFoundError:
            # flake8 executable not available: let the next linter take over.
            return None

    def py_lint(self, fname, rel_fname, code):
        """Lint Python code with flake8, then `compile()`, then tree-sitter; first finding wins."""
        error = self.flake_lint(rel_fname, code)
        if not error:
            error = lint_python_compile(fname, code)
        if not error:
            error = basic_lint(rel_fname, code)
        return error


def lint_python_compile(fname, code):
    """Compile `code` and report syntax problems as a `LintResult`.

    Returns `None` when the code compiles. Any `SyntaxError` (which includes
    `IndentationError`) is converted into a result instead of propagating, so
    a broken file cannot crash the caller when flake8 is unavailable.
    """
    try:
        compile(code, fname, "exec")  # USE TRACEBACK BELOW HERE
        return None
    except SyntaxError as err:
        end_lineno = getattr(err, "end_lineno", None)
        if end_lineno is None:
            end_lineno = err.lineno
        if isinstance(end_lineno, int):
            # NOTE(review): this reports `end_lineno - 1`, while the other
            # linters in this module report 1-based lines - confirm intent.
            line_numbers = list(range(end_lineno - 1, end_lineno))
        else:
            line_numbers = []

        tb_lines = traceback.format_exception(type(err), err, err.__traceback__)

        # Drop the traceback frames up to and including the compile() call
        # above. The marker is assembled in two steps so that this code does
        # not match itself.
        target = "# USE TRACEBACK"
        target += " BELOW HERE"
        last_file_i = 0
        for i, tb_line in enumerate(tb_lines):
            if target in tb_line:
                last_file_i = i
                break
        tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :]

    res = "".join(tb_lines)
    return LintResult(text=res, lines=line_numbers)


def basic_lint(fname, code):
    """
    Use tree-sitter to look for syntax errors, display them with tree context.

    Returns `None` when the language is unknown or no error node is found,
    otherwise a `LintResult` whose text is `"<fname>:<first error line>"`.
    """
    lang = filename_to_lang(fname)
    if not lang:
        return None

    parser = get_parser(lang)
    tree = parser.parse(bytes(code, "utf-8"))

    errors = traverse_tree(tree.root_node)
    if not errors:
        return None
    return LintResult(text=f"{fname}:{errors[0]}", lines=errors)


def extract_error_line_from(lint_error):
    """Return the first line number found in `path:line:...` style lint output.

    Each non-blank line is split on ":" and its second field is tried as an
    integer. Returns `None` when no line of the output carries a number.
    """
    # moved from openhands.agentskills#_lint_file
    for line in lint_error.splitlines(True):
        if not line.strip():
            continue
        parts = line.split(":")
        if len(parts) < 2:
            continue
        try:
            return int(parts[1])
        except ValueError:
            continue
    return None


def tree_context(fname, code, line_nums):
    """Render `code` through grep_ast's `TreeContext` with `line_nums` marked as lines of interest."""
    context = TreeContext(
        fname,
        code,
        color=False,
        line_number=True,
        child_context=False,
        last_line=False,
        margin=0,
        mark_lois=True,
        loi_pad=3,
        # header_max=30,
        show_top_of_file_parent_scope=False,
    )
    line_nums = set(line_nums)
    context.add_lines_of_interest(line_nums)
    context.add_context()
    output = context.format()

    return output


# Traverse the tree to find errors
def traverse_tree(node):
    """Collect `start_point[0] + 1` for every ERROR or missing node, in pre-order.

    Implemented iteratively so that deeply nested syntax trees cannot exhaust
    Python's recursion limit.
    """
    errors = []
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type == "ERROR" or current.is_missing:
            errors.append(current.start_point[0] + 1)
        # Push children reversed so they are popped in their original order.
        pending.extend(reversed(current.children))
    return errors


def main():
    """
    Main function to parse files provided as command line arguments.
    """
    if len(sys.argv) < 2:
        print("Usage: python linter.py ...")
        sys.exit(1)

    linter = Linter(root=os.getcwd())
    for file_path in sys.argv[1:]:
        errors = linter.lint(file_path)
        if errors:
            print(errors)


if __name__ == "__main__":
    main()
"failed" in msg - - -@pytest.mark.skip -def test_new_content_format_issue(test_file): - msg = Editor().write_content( - file_path=str(TEST_FILE_PATH), - start_line=3, - end_line=-1, - new_block_content=" # This is the new line to be inserted, at line 3 ", # trailing spaces are format issue only, and should not throw an error - ) - assert "failed" not in msg - - @pytest.mark.parametrize( "filename", [ @@ -151,5 +117,453 @@ async def test_read_files(filename): assert file_block.block_content +@pytest.fixture(autouse=True) +def reset_current_file(): + global CURRENT_FILE + CURRENT_FILE = None + + +def _numbered_test_lines(start, end) -> str: + return ("\n".join(f"{i}|" for i in range(start, end + 1))) + "\n" + + +def _generate_test_file_with_lines(temp_path, num_lines) -> str: + file_path = temp_path / "test_file.py" + file_path.write_text("\n" * num_lines) + return file_path + + +def _generate_ruby_test_file_with_lines(temp_path, num_lines) -> str: + file_path = temp_path / "test_file.rb" + file_path.write_text("\n" * num_lines) + return file_path + + +def _calculate_window_bounds(current_line, total_lines, window_size): + half_window = window_size // 2 + if current_line - half_window < 0: + start = 1 + end = window_size + else: + start = current_line - half_window + end = current_line + half_window + return start, end + + +@pytest.mark.asyncio +async def test_open_file_unexist_path(): + editor = Editor() + with pytest.raises(FileNotFoundError): + editor.open_file("/unexist/path/a.txt") + + +@pytest.mark.asyncio +async def test_open_file(tmp_path): + editor = Editor() + assert tmp_path is not None + temp_file_path = tmp_path / "a.txt" + temp_file_path.write_text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5") + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + result = buf.getvalue() + assert result is not None + expected = ( + f"[File: {temp_file_path} (5 lines total)]\n" + "(this is the beginning of the 
file)\n" + "1|Line 1\n" + "2|Line 2\n" + "3|Line 3\n" + "4|Line 4\n" + "5|Line 5\n" + "(this is the end of the file)\n" + ) + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_open_file_with_indentation(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + temp_file_path.write_text("Line 1\n Line 2\nLine 3\nLine 4\nLine 5") + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + result = buf.getvalue() + assert result is not None + expected = ( + f"[File: {temp_file_path} (5 lines total)]\n" + "(this is the beginning of the file)\n" + "1|Line 1\n" + "2| Line 2\n" + "3|Line 3\n" + "4|Line 4\n" + "5|Line 5\n" + "(this is the end of the file)\n" + ) + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_open_file_long(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + content = "\n".join([f"Line {i}" for i in range(1, 1001)]) + temp_file_path.write_text(content) + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path), 1, 50) + result = buf.getvalue() + assert result is not None + expected = f"[File: {temp_file_path} (1000 lines total)]\n" + expected += "(this is the beginning of the file)\n" + for i in range(1, 51): + expected += f"{i}|Line {i}\n" + expected += "(950 more lines below)\n" + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_open_file_long_with_lineno(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + content = "\n".join([f"Line {i}" for i in range(1, 1001)]) + temp_file_path.write_text(content) + + cur_line = 100 + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path), cur_line) + result = buf.getvalue() + assert result is not None + expected = f"[File: {temp_file_path} (1000 lines total)]\n" + start, end = 
_calculate_window_bounds(cur_line, 1000, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line {i}\n" + if end == 1000: + expected += "(this is the end of the file)\n" + else: + expected += f"({1000 - end} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_create_file_unexist_path(): + editor = Editor() + with pytest.raises(FileNotFoundError): + editor.create_file("/unexist/path/a.txt") + + +@pytest.mark.asyncio +async def test_create_file(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.create_file(str(temp_file_path)) + result = buf.getvalue() + + expected = ( + f"[File: {temp_file_path} (1 lines total)]\n" + "(this is the beginning of the file)\n" + "1|\n" + "(this is the end of the file)\n" + f"[File {temp_file_path} created.]\n" + ) + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_goto_line(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + total_lines = 1000 + content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) + temp_file_path.write_text(content) + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + result = buf.getvalue() + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + expected += "(this is the beginning of the file)\n" + for i in range(1, WINDOW + 1): + expected += f"{i}|Line {i}\n" + expected += f"({total_lines - WINDOW} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.goto_line(500) + result = buf.getvalue() + assert result is not None + + cur_line = 500 + 
expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(cur_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_goto_line_negative(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + content = "\n".join([f"Line {i}" for i in range(1, 5)]) + temp_file_path.write_text(content) + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + with pytest.raises(ValueError): + editor.goto_line(-1) + + +@pytest.mark.asyncio +async def test_goto_line_out_of_bound(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + content = "\n".join([f"Line {i}" for i in range(1, 5)]) + temp_file_path.write_text(content) + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + with pytest.raises(ValueError): + editor.goto_line(100) + + +@pytest.mark.asyncio +async def test_scroll_down(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + total_lines = 1000 + content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) + temp_file_path.write_text(content) + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + result = buf.getvalue() + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(1, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines 
above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.scroll_down() + result = buf.getvalue() + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(WINDOW + 1, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_scroll_up(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + total_lines = 1000 + content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) + temp_file_path.write_text(content) + + cur_line = 300 + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path), cur_line) + result = buf.getvalue() + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(cur_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + 
editor.scroll_up() + result = buf.getvalue() + assert result is not None + + cur_line = cur_line - WINDOW + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(cur_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)\n" + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_scroll_down_edge(tmp_path): + editor = Editor() + temp_file_path = tmp_path / "a.txt" + content = "\n".join([f"Line {i}" for i in range(1, 10)]) + temp_file_path.write_text(content) + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(temp_file_path)) + result = buf.getvalue() + assert result is not None + + expected = f"[File: {temp_file_path} (9 lines total)]\n" + expected += "(this is the beginning of the file)\n" + for i in range(1, 10): + expected += f"{i}|Line {i}\n" + expected += "(this is the end of the file)\n" + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.scroll_down() + result = buf.getvalue() + assert result is not None + + assert result.split("\n") == expected.split("\n") + + +@pytest.mark.asyncio +async def test_print_window_internal(tmp_path): + editor = Editor() + test_file_path = tmp_path / "a.txt" + await editor.create_file(str(test_file_path)) + editor.open_file(str(test_file_path)) + with open(test_file_path, "w") as file: + for i in range(1, 101): + file.write(f"Line `{i}`\n") + + current_line = 50 + window = 2 + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor._print_window(str(test_file_path), current_line, window, return_str=False) + result = buf.getvalue() + expected = "(48 
more lines above)\n" "49|Line `49`\n" "50|Line `50`\n" "51|Line `51`\n" "(49 more lines below)\n" + assert result == expected + + +@pytest.mark.asyncio +async def test_open_file_large_line_number(tmp_path): + editor = Editor() + test_file_path = tmp_path / "a.txt" + editor.create_file(str(test_file_path)) + editor.open_file(str(test_file_path)) + with open(test_file_path, "w") as file: + for i in range(1, 1000): + file.write(f"Line `{i}`\n") + + current_line = 800 + window = 100 + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(test_file_path), current_line, window) + result = buf.getvalue() + expected = f"[File: {test_file_path} (999 lines total)]\n" + expected += "(749 more lines above)\n" + for i in range(750, 850 + 1): + expected += f"{i}|Line `{i}`\n" + expected += "(149 more lines below)\n" + assert result == expected + + +@pytest.mark.asyncio +async def test_open_file_large_line_number_consecutive_diff_window(tmp_path): + editor = Editor() + test_file_path = tmp_path / "a.txt" + editor.create_file(str(test_file_path)) + editor.open_file(str(test_file_path)) + total_lines = 1000 + with open(test_file_path, "w") as file: + for i in range(1, total_lines + 1): + file.write(f"Line `{i}`\n") + + current_line = 800 + cur_window = 300 + + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.open_file(str(test_file_path), current_line, cur_window) + result = buf.getvalue() + expected = f"[File: {test_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(current_line, total_lines, cur_window) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(current_line - cur_window // 2, current_line + cur_window // 2 + 1): + expected += f"{i}|Line `{i}`\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)\n" 
+ assert result == expected + + current_line = current_line - WINDOW + with io.StringIO() as buf: + with contextlib.redirect_stdout(buf): + editor.scroll_up() + + if __name__ == "__main__": pytest.main([__file__, "-s"])