feat: +pdf

This commit is contained in:
莘权 马 2024-09-05 19:24:12 +08:00
parent 285a6bf164
commit ad80dab678
4 changed files with 146 additions and 108 deletions

View file

@ -4,25 +4,21 @@ You can find the original repository here:
https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py
"""
import asyncio
import base64
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, IndexRepo
from metagpt.tools.libs.linter import Linter
from metagpt.tools.tool_registry import register_tool
from metagpt.utils import read_docx
from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint
from metagpt.utils.repo_to_markdown import is_text_file
from metagpt.utils.file import File
from metagpt.utils.report import EditorReporter
# This is also used in unit tests!
@ -72,23 +68,12 @@ class Editor(BaseModel):
async def read(self, path: str) -> FileBlock:
"""Read the whole content of a file. Using absolute paths as the argument for specifying the file location."""
is_text, mime_type = await is_text_file(path)
if is_text:
lines = await self._read_text(path)
elif mime_type == "application/pdf":
lines = await self._read_pdf(path)
elif mime_type in {
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-word.document.macroEnabled.12",
"application/vnd.openxmlformats-officedocument.wordprocessingml.template",
"application/vnd.ms-word.template.macroEnabled.12",
}:
lines = await self._read_docx(path)
else:
content = await File.read_text_file(path)
if not content:
return FileBlock(file_path=str(path), block_content="")
self.resource.report(str(path), "path")
lines = content.splitlines(keepends=True)
lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)]
result = FileBlock(
file_path=str(path),
@ -96,80 +81,6 @@ class Editor(BaseModel):
)
return result
@staticmethod
async def _read_text(path: Union[str, Path]) -> List[str]:
    """Read a text file and return its content split on ``"\\n"``.

    Args:
        path: Location of the file, passed straight to ``aread``.

    Returns:
        The pieces of the content between newline characters; the newline
        characters themselves are not kept.
    """
    text = await aread(path)
    return text.split("\n")
@staticmethod
async def _read_pdf(path: Union[str, Path]) -> List[str]:
    """Extract the text of a PDF, trying OmniParse before the local reader.

    Args:
        path: Location of the PDF file.

    Returns:
        The OmniParse result when it is non-empty; otherwise the ``text`` of
        each item that ``PDFReader.load_data`` returns for the file.
    """
    parsed = await Editor._omniparse_read_file(path)
    if parsed:
        return parsed

    # Fallback path only: the reader is imported here, not at module level.
    from llama_index.readers.file import PDFReader

    pdf_reader = PDFReader()
    documents = pdf_reader.load_data(file=Path(path))
    return [doc.text for doc in documents]
@staticmethod
async def _read_docx(path: Union[str, Path]) -> List[str]:
    """Extract the text of a Word document, trying OmniParse first.

    Args:
        path: Location of the document.

    Returns:
        The OmniParse result when it is non-empty; otherwise whatever
        ``read_docx`` returns for the path.
    """
    parsed = await Editor._omniparse_read_file(path)
    # An empty or missing OmniParse result falls through to the local reader.
    return parsed or read_docx(str(path))
@staticmethod
async def _omniparse_read_file(path: Union[str, Path]) -> Optional[List[str]]:
    """Parse a document with an OmniParse service, if one is configured.

    The base URL and timeout are looked up through ``get_env_default``
    (``app_name="OmniParse"``) first and then through
    ``Editor._read_omniparse_config``; the timeout falls back to 600 when
    neither source yields a value that converts to ``int``.

    Every image in the parse result is base64-decoded and written to a
    ``<file name with dots replaced by underscores>_images`` directory next
    to ``path``, and a Markdown image link for it is appended to the result.

    Args:
        path: Location of the document to parse.

    Returns:
        ``None`` when no base URL is configured, the endpoint check fails,
        parsing raises, or the result has neither text nor images. Otherwise
        a list whose first item is the parsed text, followed by one Markdown
        image link per saved image.
    """
    from metagpt.tools.libs import get_env_default
    from metagpt.utils.omniparse_client import OmniParseClient

    env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="")
    env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="")
    conf_base_url, conf_timeout = await Editor._read_omniparse_config()

    base_url = env_base_url or conf_base_url
    if not base_url:
        return None
    api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="")

    try:
        timeout = int(env_timeout or conf_timeout or 600)
    except (TypeError, ValueError):
        # A malformed or non-numeric setting must not prevent parsing.
        timeout = 600

    try:
        if not await check_http_endpoint(url=base_url):
            logger.warning(f"{base_url}: NOT AVAILABLE")
            return None
        client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout)
        file_data = await aread_bin(filename=path)
        ret = await client.parse_document(file_input=file_data, bytes_filename=str(path))
    except Exception as e:
        # Best effort: any failure here lets the caller fall back to a local reader.
        logger.exception(f"{path}: {e}")
        return None

    if not ret.images:
        return [ret.text] if ret.text else None

    result = [ret.text]
    source = Path(path)
    img_dir = source.parent / (source.name.replace(".", "_") + "_images")
    img_dir.mkdir(parents=True, exist_ok=True)
    for image in ret.images:
        # The name comes from the service response; keep only its final
        # component so that it cannot point outside ``img_dir``.
        safe_name = Path(image.image_name).name
        if safe_name in {"", ".", ".."}:
            logger.warning(f"{path}: skipped image with unusable name {image.image_name!r}")
            continue
        byte_data = base64.b64decode(image.image)
        filename = img_dir / safe_name
        await awrite_bin(filename=filename, data=byte_data)
        result.append(f"![{safe_name}]({str(filename)})")
    return result
@staticmethod
async def _read_omniparse_config() -> Tuple[str, int]:
    """Return the OmniParse URL and timeout from the default configuration.

    Returns:
        ``(url, timeout)`` taken from ``Config.default().omniparse`` when that
        section exists and has a non-empty ``url``; ``("", 0)`` otherwise.
    """
    omniparse = Config.default().omniparse
    if not omniparse or not omniparse.url:
        return "", 0
    return omniparse.url, omniparse.timeout
@staticmethod
def _is_valid_filename(file_name: str) -> bool:
if not file_name or not file_name.strip():
@ -985,7 +896,7 @@ class Editor(BaseModel):
futures.append(repo.search(query=query, filenames=list(filenames)))
for i in others:
futures.append(aread(filename=i))
futures.append(File.read_text_file(i))
futures_results = []
if futures:

View file

@ -16,8 +16,8 @@ from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig
from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
from metagpt.utils.repo_to_markdown import is_text_file
from metagpt.utils.common import awrite, generate_fingerprint, list_files
from metagpt.utils.file import File
UPLOADS_INDEX_ROOT = "/data/.index/uploads"
DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT
@ -82,13 +82,13 @@ class IndexRepo(BaseModel):
filenames, _ = await self._filter(filenames)
filter_filenames = set()
for i in filenames:
content = await aread(filename=i)
content = await File.read_text_file(i)
token_count = len(encoding.encode(content))
if not self._is_buildable(token_count):
result.append(TextScore(filename=str(i), text=content))
continue
file_fingerprint = generate_fingerprint(content)
if self.fingerprints.get(str(i)) != file_fingerprint:
if self.fingerprints.get(str(i)) != file_fingerprint and Path(i).suffix.lower() not in {".pdf"}:
logger.error(f'file: "{i}" changed but not indexed')
continue
filter_filenames.add(str(i))
@ -107,7 +107,7 @@ class IndexRepo(BaseModel):
Returns:
List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity.
"""
flat_nodes = [node for indices in indices_list for node in indices]
flat_nodes = [node for indices in indices_list if indices for node in indices if node]
if len(flat_nodes) <= self.recall_count:
return flat_nodes
@ -138,7 +138,7 @@ class IndexRepo(BaseModel):
filter_filenames = []
delete_filenames = []
for i in filenames:
content = await aread(filename=i)
content = await File.read_text_file(i)
if not self._is_fingerprint_changed(filename=i, content=content):
continue
token_count = len(encoding.encode(content))
@ -186,7 +186,7 @@ class IndexRepo(BaseModel):
logger.debug(f"add docs {filenames}")
engine.persist(persist_dir=self.persist_path)
for i in filenames:
content = await aread(i)
content = await File.read_text_file(i)
fp = generate_fingerprint(content)
self.fingerprints[str(i)] = fp
await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
@ -233,13 +233,13 @@ class IndexRepo(BaseModel):
logger.debug(f"{path} not is_relative_to {root_path})")
continue
if not path.is_dir():
is_text, _ = await is_text_file(path)
is_text = await File.is_textual_file(path)
if is_text:
pathnames.append(path)
continue
subfiles = list_files(path)
for j in subfiles:
is_text, _ = await is_text_file(j)
is_text = await File.is_textual_file(j)
if is_text:
pathnames.append(j)

View file

@ -6,13 +6,19 @@
@File : file.py
@Describe : General file operations.
"""
import base64
from pathlib import Path
from typing import Optional, Tuple, Union
import aiofiles
from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem
from metagpt.config2 import Config
from metagpt.logs import logger
from metagpt.utils import read_docx
from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.repo_to_markdown import is_text_file
class File:
@ -70,6 +76,125 @@ class File:
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content
@staticmethod
async def is_textual_file(filename: Union[str, Path]) -> bool:
    """Tell whether a file's text can be obtained by ``read_text_file``.

    A file qualifies when ``is_text_file`` reports it as text, or when its
    MIME type is PDF or one of the listed Microsoft Word types.

    Args:
        filename (Union[str, Path]): The path to the file to be checked.

    Returns:
        bool: True if the file is a textual file, False otherwise.
    """
    word_mime_types = {
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-word.document.macroEnabled.12",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
        "application/vnd.ms-word.template.macroEnabled.12",
    }
    is_text, mime_type = await is_text_file(filename)
    return bool(is_text) or mime_type == "application/pdf" or mime_type in word_mime_types
@staticmethod
async def read_text_file(filename: Union[str, Path]) -> Optional[str]:
    """Return the textual content of a plain-text, PDF or Word file.

    The reader is chosen from the result of ``is_text_file``: plain text goes
    to ``_read_text``, ``application/pdf`` to ``_read_pdf``, and the listed
    Microsoft Word MIME types to ``_read_docx``.

    Args:
        filename: Location of the file to read.

    Returns:
        The value returned by the chosen reader, or ``None`` when the file
        matches none of the supported kinds.
    """
    word_mime_types = {
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-word.document.macroEnabled.12",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
        "application/vnd.ms-word.template.macroEnabled.12",
    }
    is_text, mime_type = await is_text_file(filename)
    if is_text:
        reader = File._read_text
    elif mime_type == "application/pdf":
        reader = File._read_pdf
    elif mime_type in word_mime_types:
        reader = File._read_docx
    else:
        return None
    return await reader(filename)
@staticmethod
async def _read_text(path: Union[str, Path]) -> str:
    """Return the content of a text file as read by ``aread``.

    Args:
        path: Location of the file.
    """
    content = await aread(path)
    return content
@staticmethod
async def _read_pdf(path: Union[str, Path]) -> str:
    """Extract the text of a PDF, trying OmniParse before the local reader.

    Args:
        path: Location of the PDF file.

    Returns:
        The OmniParse result when it is non-empty; otherwise the ``text`` of
        every item returned by ``PDFReader.load_data``, joined with newlines.
    """
    parsed = await File._omniparse_read_file(path)
    if parsed:
        return parsed

    # Fallback path only: the reader is imported here, not at module level.
    from llama_index.readers.file import PDFReader

    pdf_reader = PDFReader()
    documents = pdf_reader.load_data(file=Path(path))
    return "\n".join(doc.text for doc in documents)
@staticmethod
async def _read_docx(path: Union[str, Path]) -> str:
    """Extract the text of a Word document, trying OmniParse first.

    Args:
        path: Location of the document.

    Returns:
        The OmniParse result when it is non-empty; otherwise the items
        returned by ``read_docx`` joined with newlines.
    """
    parsed = await File._omniparse_read_file(path)
    if not parsed:
        paragraphs = read_docx(str(path))
        return "\n".join(paragraphs)
    return parsed
@staticmethod
async def _omniparse_read_file(path: Union[str, Path], auto_save_image: bool = False) -> Optional[str]:
    """Parse a document with an OmniParse service, if one is configured.

    The base URL and timeout are looked up through ``get_env_default``
    (``app_name="OmniParse"``) first and then through
    ``File._read_omniparse_config``; the timeout falls back to 600 when
    neither source yields a value that converts to ``int``.

    Args:
        path: Location of the document to parse.
        auto_save_image: When true and the parse result contains images, each
            image is base64-decoded and written to a
            ``<file name with dots replaced by underscores>_images`` directory
            next to ``path``, and a Markdown image link for it is appended to
            the returned text.

    Returns:
        ``None`` when no base URL is configured, the endpoint check fails, or
        parsing raises. Otherwise the parsed text, followed by one Markdown
        image link per saved image (newline-separated) when images were saved.
    """
    from metagpt.tools.libs import get_env_default
    from metagpt.utils.omniparse_client import OmniParseClient

    env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="")
    env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="")
    conf_base_url, conf_timeout = await File._read_omniparse_config()

    base_url = env_base_url or conf_base_url
    if not base_url:
        return None
    api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="")

    try:
        timeout = int(env_timeout or conf_timeout or 600)
    except (TypeError, ValueError):
        # A malformed or non-numeric setting must not prevent parsing.
        timeout = 600

    try:
        if not await check_http_endpoint(url=base_url):
            logger.warning(f"{base_url}: NOT AVAILABLE")
            return None
        client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout)
        file_data = await aread_bin(filename=path)
        ret = await client.parse_document(file_input=file_data, bytes_filename=str(path))
    except Exception as e:
        # Best effort: any failure here lets the caller fall back to a local reader.
        logger.exception(f"{path}: {e}")
        return None

    if not ret.images or not auto_save_image:
        return ret.text

    result = [ret.text]
    source = Path(path)
    img_dir = source.parent / (source.name.replace(".", "_") + "_images")
    img_dir.mkdir(parents=True, exist_ok=True)
    for image in ret.images:
        # The name comes from the service response; keep only its final
        # component so that it cannot point outside ``img_dir``.
        safe_name = Path(image.image_name).name
        if safe_name in {"", ".", ".."}:
            logger.warning(f"{path}: skipped image with unusable name {image.image_name!r}")
            continue
        byte_data = base64.b64decode(image.image)
        filename = img_dir / safe_name
        await awrite_bin(filename=filename, data=byte_data)
        result.append(f"![{safe_name}]({str(filename)})")
    return "\n".join(result)
@staticmethod
async def _read_omniparse_config() -> Tuple[str, int]:
    """Return the OmniParse URL and timeout from the default configuration.

    Returns:
        ``(url, timeout)`` taken from ``Config.default().omniparse`` when that
        section exists and has a non-empty ``url``; ``("", 0)`` otherwise.
    """
    settings = Config.default().omniparse
    has_url = bool(settings and settings.url)
    return (settings.url, settings.timeout) if has_url else ("", 0)
class MemoryFileSystem(_MemoryFileSystem):
@classmethod

View file

@ -665,7 +665,7 @@ async def mock_index_repo():
command = f"cp -rf {str(src_path)} {str(chat_path)}"
os.system(command)
filenames = list_files(chat_path)
chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}]
chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}]
chat_repo = IndexRepo(
persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0
)
@ -675,12 +675,12 @@ async def mock_index_repo():
command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}"
os.system(command)
filenames = list_files(UPLOAD_ROOT)
uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}]
uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}]
uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0)
await uploads_repo.add(uploads_files)
filenames = list_files(src_path)
other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}]
other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}]
return chat_files, uploads_files, other_files
@ -692,7 +692,9 @@ async def test_index_repo():
chat_files, uploads_files, other_files = await mock_index_repo()
editor = Editor()
rsp = await editor.vsearch(query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0)
rsp = await editor.search_index_repo(
query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0
)
assert rsp
shutil.rmtree(CHATS_ROOT)