From ad80dab67859385e911816390034979e8a2f4ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 5 Sep 2024 19:24:12 +0800 Subject: [PATCH] feat: +pdf --- metagpt/tools/libs/editor.py | 101 ++----------------- metagpt/tools/libs/index_repo.py | 18 ++-- metagpt/utils/file.py | 125 ++++++++++++++++++++++++ tests/metagpt/tools/libs/test_editor.py | 10 +- 4 files changed, 146 insertions(+), 108 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 9a1cf63c3..cf56eb292 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -4,25 +4,21 @@ You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ import asyncio -import base64 import os import re import shutil import tempfile from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from pydantic import BaseModel, ConfigDict -from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool -from metagpt.utils import read_docx -from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint -from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.file import File from metagpt.utils.report import EditorReporter # This is also used in unit tests! @@ -72,23 +68,12 @@ class Editor(BaseModel): async def read(self, path: str) -> FileBlock: """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" - is_text, mime_type = await is_text_file(path) - if is_text: - lines = await self._read_text(path) - elif mime_type == "application/pdf": - lines = await self._read_pdf(path) - elif mime_type in { - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-word.document.macroEnabled.12", - "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - "application/vnd.ms-word.template.macroEnabled.12", - }: - lines = await self._read_docx(path) - else: + content = await File.read_text_file(path) + if not content: return FileBlock(file_path=str(path), block_content="") self.resource.report(str(path), "path") + lines = content.splitlines(keepends=True) lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] result = FileBlock( file_path=str(path), @@ -96,80 +81,6 @@ class Editor(BaseModel): ) return result - @staticmethod - async def _read_text(path: Union[str, Path]) -> List[str]: - content = await aread(path) - lines = content.split("\n") - return lines - - @staticmethod - async def _read_pdf(path: Union[str, Path]) -> List[str]: - result = await Editor._omniparse_read_file(path) - if result: - return result - - from llama_index.readers.file import PDFReader - - reader = PDFReader() - lines = reader.load_data(file=Path(path)) - return [i.text for i in lines] - - @staticmethod - async def _read_docx(path: Union[str, Path]) -> List[str]: - result = await Editor._omniparse_read_file(path) - if result: - return result - return read_docx(str(path)) - - @staticmethod - async def _omniparse_read_file(path: Union[str, Path]) -> Optional[List[str]]: - from metagpt.tools.libs import get_env_default - from metagpt.utils.omniparse_client import OmniParseClient - - env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") - env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") - conf_base_url, conf_timeout = await Editor._read_omniparse_config() - - base_url = env_base_url or conf_base_url - if not base_url: - return None - api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") - timeout = env_timeout or conf_timeout or 600 - try: - timeout = int(timeout) - except ValueError: - timeout = 600 - - try: - if not await check_http_endpoint(url=base_url): - logger.warning(f"{base_url}: NOT AVAILABLE") - return None - client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) - file_data = await aread_bin(filename=path) - ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) - except (ValueError, Exception) as e: - logger.exception(f"{path}: {e}") - return None - if not ret.images: - return [ret.text] if ret.text else None - - result = [ret.text] - img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") - img_dir.mkdir(parents=True, exist_ok=True) - for i in ret.images: - byte_data = base64.b64decode(i.image) - filename = img_dir / i.image_name - await awrite_bin(filename=filename, data=byte_data) - result.append(f"![{i.image_name}]({str(filename)})") - return result - - @staticmethod - async def _read_omniparse_config() -> Tuple[str, int]: - config = Config.default() - if config.omniparse and config.omniparse.url: - return config.omniparse.url, config.omniparse.timeout - return "", 0 - @staticmethod def _is_valid_filename(file_name: str) -> bool: if not file_name or not file_name.strip(): @@ -985,7 +896,7 @@ class Editor(BaseModel): futures.append(repo.search(query=query, filenames=list(filenames))) for i in others: - futures.append(aread(filename=i)) + futures.append(File.read_text_file(i)) futures_results = [] if futures: diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index b4907d74e..d29254318 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -16,8 +16,8 @@ from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.factories.embedding import RAGEmbeddingFactory from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig -from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files -from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.common import awrite, generate_fingerprint, list_files +from metagpt.utils.file import File UPLOADS_INDEX_ROOT = "/data/.index/uploads" DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT @@ -82,13 +82,13 @@ class IndexRepo(BaseModel): filenames, _ = await self._filter(filenames) filter_filenames = set() for i in filenames: - content = await aread(filename=i) + content = await File.read_text_file(i) token_count = len(encoding.encode(content)) if not self._is_buildable(token_count): result.append(TextScore(filename=str(i), text=content)) continue file_fingerprint = generate_fingerprint(content) - if self.fingerprints.get(str(i)) != file_fingerprint: + if self.fingerprints.get(str(i)) != file_fingerprint and Path(i).suffix.lower() not in {".pdf"}: logger.error(f'file: "{i}" changed but not indexed') continue filter_filenames.add(str(i)) @@ -107,7 +107,7 @@ class IndexRepo(BaseModel): Returns: List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. """ - flat_nodes = [node for indices in indices_list for node in indices] + flat_nodes = [node for indices in indices_list if indices for node in indices if node] if len(flat_nodes) <= self.recall_count: return flat_nodes @@ -138,7 +138,7 @@ class IndexRepo(BaseModel): filter_filenames = [] delete_filenames = [] for i in filenames: - content = await aread(filename=i) + content = await File.read_text_file(i) if not self._is_fingerprint_changed(filename=i, content=content): continue token_count = len(encoding.encode(content)) @@ -186,7 +186,7 @@ class IndexRepo(BaseModel): logger.debug(f"add docs {filenames}") engine.persist(persist_dir=self.persist_path) for i in filenames: - content = await aread(i) + content = await File.read_text_file(i) fp = generate_fingerprint(content) self.fingerprints[str(i)] = fp await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) @@ -233,13 +233,13 @@ class IndexRepo(BaseModel): logger.debug(f"{path} not is_relative_to {root_path})") continue if not path.is_dir(): - is_text, _ = await is_text_file(path) + is_text = await File.is_textual_file(path) if is_text: pathnames.append(path) continue subfiles = list_files(path) for j in subfiles: - is_text, _ = await is_text_file(j) + is_text = await File.is_textual_file(j) if is_text: pathnames.append(j) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index 8861f65dc..a3f612bcc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -6,13 +6,19 @@ @File : file.py @Describe : General file operations. """ +import base64 from pathlib import Path +from typing import Optional, Tuple, Union import aiofiles from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem +from metagpt.config2 import Config from metagpt.logs import logger +from metagpt.utils import read_docx +from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint from metagpt.utils.exceptions import handle_exception +from metagpt.utils.repo_to_markdown import is_text_file class File: @@ -70,6 +76,125 @@ class File: logger.debug(f"Successfully read file, the path of file: {file_path}") return content + @staticmethod + async def is_textual_file(filename: Union[str, Path]) -> bool: + """Determines if a given file is a textual file. + + A file is considered a textual file if it is plain text or has a + specific set of MIME types associated with textual formats, + including PDF and Microsoft Word documents. + + Args: + filename (Union[str, Path]): The path to the file to be checked. + + Returns: + bool: True if the file is a textual file, False otherwise. + """ + is_text, mime_type = await is_text_file(filename) + if is_text: + return True + if mime_type == "application/pdf": + return True + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return True + return False + + @staticmethod + async def read_text_file(filename: Union[str, Path]) -> Optional[str]: + """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" + is_text, mime_type = await is_text_file(filename) + if is_text: + return await File._read_text(filename) + if mime_type == "application/pdf": + return await File._read_pdf(filename) + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return await File._read_docx(filename) + return None + + @staticmethod + async def _read_text(path: Union[str, Path]) -> str: + return await aread(path) + + @staticmethod + async def _read_pdf(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + + from llama_index.readers.file import PDFReader + + reader = PDFReader() + lines = reader.load_data(file=Path(path)) + return "\n".join([i.text for i in lines]) + + @staticmethod + async def _read_docx(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + return "\n".join(read_docx(str(path))) + + @staticmethod + async def _omniparse_read_file(path: Union[str, Path], auto_save_image: bool = False) -> Optional[str]: + from metagpt.tools.libs import get_env_default + from metagpt.utils.omniparse_client import OmniParseClient + + env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") + env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") + conf_base_url, conf_timeout = await File._read_omniparse_config() + + base_url = env_base_url or conf_base_url + if not base_url: + return None + api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") + timeout = env_timeout or conf_timeout or 600 + try: + timeout = int(timeout) + except ValueError: + timeout = 600 + + try: + if not await check_http_endpoint(url=base_url): + logger.warning(f"{base_url}: NOT AVAILABLE") + return None + client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) + file_data = await aread_bin(filename=path) + ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) + except (ValueError, Exception) as e: + logger.exception(f"{path}: {e}") + return None + if not ret.images or not auto_save_image: + return ret.text + + result = [ret.text] + img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") + img_dir.mkdir(parents=True, exist_ok=True) + for i in ret.images: + byte_data = base64.b64decode(i.image) + filename = img_dir / i.image_name + await awrite_bin(filename=filename, data=byte_data) + result.append(f"![{i.image_name}]({str(filename)})") + return "\n".join(result) + + @staticmethod + async def _read_omniparse_config() -> Tuple[str, int]: + config = Config.default() + if config.omniparse and config.omniparse.url: + return config.omniparse.url, config.omniparse.timeout + return "", 0 + class MemoryFileSystem(_MemoryFileSystem): @classmethod diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 26d53a703..3a4cf65fe 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -665,7 +665,7 @@ async def mock_index_repo(): command = f"cp -rf {str(src_path)} {str(chat_path)}" os.system(command) filenames = list_files(chat_path) - chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] chat_repo = IndexRepo( persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 ) @@ -675,12 +675,12 @@ async def mock_index_repo(): command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" os.system(command) filenames = list_files(UPLOAD_ROOT) - uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) await uploads_repo.add(uploads_files) filenames = list_files(src_path) - other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] return chat_files, uploads_files, other_files @@ -692,7 +692,9 @@ async def test_index_repo(): chat_files, uploads_files, other_files = await mock_index_repo() editor = Editor() - rsp = await editor.vsearch(query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0) + rsp = await editor.search_index_repo( + query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0 + ) assert rsp shutil.rmtree(CHATS_ROOT)