mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 18:36:22 +02:00
update: resolve comments and support any directory zip file
This commit is contained in:
parent
d8afe9ea86
commit
5e27ee4bad
7 changed files with 62 additions and 102 deletions
|
|
@ -10,12 +10,10 @@
|
|||
|
||||
import os
|
||||
import zipfile
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from typing import Dict, List
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
from metagpt.actions import Action
|
||||
|
|
@ -26,36 +24,6 @@ from metagpt.utils.common import OutputParser
|
|||
from metagpt.utils.file import File
|
||||
|
||||
|
||||
class FileExtensionType(Enum):
|
||||
"""Enum representing file extensions and their associated types.
|
||||
Each enum member consists of a tuple containing the file extension and its associated type.
|
||||
|
||||
"""
|
||||
|
||||
ZIP = (".zip", "zip")
|
||||
PDF = (".pdf", "pdf")
|
||||
PNG = (".png", "png")
|
||||
JPG = (".jpg", "jpg")
|
||||
|
||||
@classmethod
|
||||
def get_extension_list(cls) -> List[str]:
|
||||
"""Get a list of file extensions.
|
||||
|
||||
Returns:
|
||||
A list of file extensions as strings.
|
||||
"""
|
||||
return [ext.value[0] for ext in cls]
|
||||
|
||||
@classmethod
|
||||
def get_type_list(cls) -> List[str]:
|
||||
"""Get a list of file types.
|
||||
|
||||
Returns:
|
||||
A list of file types as strings.
|
||||
"""
|
||||
return [ext.value[1] for ext in cls]
|
||||
|
||||
|
||||
class InvoiceOCR(Action):
|
||||
"""Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files.
|
||||
|
||||
|
|
@ -69,11 +37,11 @@ class InvoiceOCR(Action):
|
|||
super().__init__(name, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def _check_file_type(filename: str) -> str:
|
||||
async def _check_file_type(file_path: Path) -> str:
|
||||
"""Check the file type of the given filename.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
file_path: The path of the file.
|
||||
|
||||
Returns:
|
||||
The file type based on FileExtensionType enum.
|
||||
|
|
@ -81,16 +49,11 @@ class InvoiceOCR(Action):
|
|||
Raises:
|
||||
Exception: If the file format is not zip, pdf, png, or jpg.
|
||||
"""
|
||||
file_ext = None
|
||||
for ext in FileExtensionType:
|
||||
if filename.endswith(ext.value[0]):
|
||||
file_ext = ext.value[1]
|
||||
break
|
||||
|
||||
if not file_ext:
|
||||
ext = file_path.suffix
|
||||
if ext not in [".zip", ".pdf", ".png", ".jpg"]:
|
||||
raise Exception("The invoice format is not zip, pdf, png, or jpg")
|
||||
|
||||
return file_ext
|
||||
return ext
|
||||
|
||||
@staticmethod
|
||||
async def _unzip(file_path: Path) -> Path:
|
||||
|
|
@ -102,51 +65,52 @@ class InvoiceOCR(Action):
|
|||
Returns:
|
||||
The path to the unzipped directory.
|
||||
"""
|
||||
file_directory = file_path.parent
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
||||
file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
||||
for zip_info in zip_ref.infolist():
|
||||
# Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code
|
||||
relative_name = zip_info.filename.encode('cp437').decode('gbk')
|
||||
unzip_dir, name = relative_name.split("/")
|
||||
if name:
|
||||
relative_name = Path(zip_info.filename.encode("cp437").decode("gbk"))
|
||||
if relative_name.suffix:
|
||||
name = relative_name.name
|
||||
full_filename = file_directory / relative_name
|
||||
await File.write(full_filename.parent, name, zip_ref.read(zip_info.filename))
|
||||
|
||||
unzip_path = file_directory / unzip_dir
|
||||
logger.info(f"unzip_path: {unzip_path}")
|
||||
return unzip_path
|
||||
logger.info(f"unzip_path: {file_directory}")
|
||||
return file_directory
|
||||
|
||||
async def run(self, file_path: Path, filename: str, *args, **kwargs) -> list:
|
||||
@staticmethod
|
||||
async def _ocr(invoice_file_path: Path):
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
|
||||
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
|
||||
return ocr_result
|
||||
|
||||
async def run(self, file_path: Path, *args, **kwargs) -> list:
|
||||
"""Execute the action to identify invoice files through OCR.
|
||||
|
||||
Args:
|
||||
file_path: The path to the input file.
|
||||
filename: The name of the input file.
|
||||
|
||||
Returns:
|
||||
A list of OCR results.
|
||||
"""
|
||||
file_ext = await self._check_file_type(filename)
|
||||
file_ext = await self._check_file_type(file_path)
|
||||
|
||||
if file_ext == FileExtensionType.ZIP.value[1]:
|
||||
if file_ext == ".zip":
|
||||
# OCR recognizes zip batch files
|
||||
unzip_path = await self._unzip(file_path)
|
||||
file_list = os.listdir(unzip_path)
|
||||
ocr_list = []
|
||||
|
||||
for filename in file_list:
|
||||
invoice_file_path = unzip_path / filename
|
||||
# Identify files that match the type
|
||||
if filename.split(".")[-1] in FileExtensionType.get_type_list():
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
|
||||
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
|
||||
ocr_list.append(ocr_result)
|
||||
for root, _, files in os.walk(unzip_path):
|
||||
for filename in files:
|
||||
invoice_file_path = Path(root) / Path(filename)
|
||||
# Identify files that match the type
|
||||
if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]:
|
||||
ocr_result = await self._ocr(str(invoice_file_path))
|
||||
ocr_list.append(ocr_result)
|
||||
return ocr_list
|
||||
|
||||
else:
|
||||
# OCR identifies single file
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
|
||||
ocr_result = ocr.ocr(str(file_path), cls=True)
|
||||
ocr_result = await self._ocr(file_path)
|
||||
return [ocr_result]
|
||||
|
||||
|
||||
|
|
@ -163,7 +127,7 @@ class GenerateTable(Action):
|
|||
super().__init__(name, *args, **kwargs)
|
||||
self.language = language
|
||||
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> Dict[str, str]:
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
|
||||
"""Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue