update: resolve review comments and support zip files with any directory structure

This commit is contained in:
Stitch-z 2023-10-24 14:01:33 +08:00
parent d8afe9ea86
commit 5e27ee4bad
7 changed files with 62 additions and 102 deletions

View file

@ -8,7 +8,7 @@
"""
import asyncio
import os
from pathlib import Path
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
@ -16,15 +16,13 @@ from metagpt.schema import Message
async def main():
relative_paths = [
"../tests/data/invoices/invoice-1.pdf",
"../tests/data/invoices/invoice-2.png",
"../tests/data/invoices/invoice-3.jpg",
"../tests/data/invoices/invoice-4.zip"
Path("../tests/data/invoices/invoice-1.pdf"),
Path("../tests/data/invoices/invoice-2.png"),
Path("../tests/data/invoices/invoice-3.jpg"),
Path("../tests/data/invoices/invoice-4.zip")
]
# Get the current working directory
current_directory = os.getcwd()
# The absolute path of the file
absolute_file_paths = [os.path.abspath(os.path.join(current_directory, path)) for path in relative_paths]
absolute_file_paths = [Path.cwd() / path for path in relative_paths]
for path in absolute_file_paths:
role = InvoiceOCRAssistant()

View file

@ -10,12 +10,10 @@
import os
import zipfile
from enum import Enum
from pathlib import Path
from datetime import datetime
import pandas as pd
from typing import Dict, List
from paddleocr import PaddleOCR
from metagpt.actions import Action
@ -26,36 +24,6 @@ from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
class FileExtensionType(Enum):
    """Invoice file kinds recognised by the OCR action.

    The value of every member is an ``(extension, type)`` pair, where the
    extension carries the leading dot and the type does not.
    """

    ZIP = (".zip", "zip")
    PDF = (".pdf", "pdf")
    PNG = (".png", "png")
    JPG = (".jpg", "jpg")

    @classmethod
    def get_extension_list(cls) -> List[str]:
        """Collect the dotted extension of every member.

        Returns:
            The extensions as strings, in member definition order.
        """
        extensions = []
        for member in cls:
            # First element of the value pair is the dotted extension.
            suffix, _ = member.value
            extensions.append(suffix)
        return extensions

    @classmethod
    def get_type_list(cls) -> List[str]:
        """Collect the type name of every member.

        Returns:
            The type names as strings, in member definition order.
        """
        type_names = []
        for member in cls:
            # Second element of the value pair is the bare type name.
            _, type_name = member.value
            type_names.append(type_name)
        return type_names
class InvoiceOCR(Action):
"""Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files.
@ -69,11 +37,11 @@ class InvoiceOCR(Action):
super().__init__(name, *args, **kwargs)
@staticmethod
async def _check_file_type(filename: str) -> str:
async def _check_file_type(file_path: Path) -> str:
"""Check the file type of the given filename.
Args:
filename: The name of the file.
file_path: The path of the file.
Returns:
The file extension of the path, including the leading dot (for example ".zip").
@ -81,16 +49,11 @@ class InvoiceOCR(Action):
Raises:
Exception: If the file format is not zip, pdf, png, or jpg.
"""
file_ext = None
for ext in FileExtensionType:
if filename.endswith(ext.value[0]):
file_ext = ext.value[1]
break
if not file_ext:
ext = file_path.suffix
if ext not in [".zip", ".pdf", ".png", ".jpg"]:
raise Exception("The invoice format is not zip, pdf, png, or jpg")
return file_ext
return ext
@staticmethod
async def _unzip(file_path: Path) -> Path:
@ -102,51 +65,52 @@ class InvoiceOCR(Action):
Returns:
The path to the unzipped directory.
"""
file_directory = file_path.parent
with zipfile.ZipFile(file_path, 'r') as zip_ref:
file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S")
with zipfile.ZipFile(file_path, "r") as zip_ref:
for zip_info in zip_ref.infolist():
# Re-encode the entry name as CP437 and decode it as GBK so that Chinese file names are not garbled
relative_name = zip_info.filename.encode('cp437').decode('gbk')
unzip_dir, name = relative_name.split("/")
if name:
relative_name = Path(zip_info.filename.encode("cp437").decode("gbk"))
if relative_name.suffix:
name = relative_name.name
full_filename = file_directory / relative_name
await File.write(full_filename.parent, name, zip_ref.read(zip_info.filename))
unzip_path = file_directory / unzip_dir
logger.info(f"unzip_path: {unzip_path}")
return unzip_path
logger.info(f"unzip_path: {file_directory}")
return file_directory
async def run(self, file_path: Path, filename: str, *args, **kwargs) -> list:
@staticmethod
async def _ocr(invoice_file_path: Path):
    """Run PaddleOCR on one invoice file.

    Args:
        invoice_file_path: Location of the invoice file; it is converted to
            a string before being handed to the OCR engine.

    Returns:
        Whatever ``PaddleOCR.ocr`` returns for the file.
    """
    # A fresh engine is created on every call.
    engine = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
    return engine.ocr(str(invoice_file_path), cls=True)
async def run(self, file_path: Path, *args, **kwargs) -> list:
"""Execute the action to identify invoice files through OCR.
Args:
file_path: The path to the input file.
filename: The name of the input file.
Returns:
A list of OCR results.
"""
file_ext = await self._check_file_type(filename)
file_ext = await self._check_file_type(file_path)
if file_ext == FileExtensionType.ZIP.value[1]:
if file_ext == ".zip":
# Zip archive: extract it and run OCR on each supported file inside
unzip_path = await self._unzip(file_path)
file_list = os.listdir(unzip_path)
ocr_list = []
for filename in file_list:
invoice_file_path = unzip_path / filename
# Identify files that match the type
if filename.split(".")[-1] in FileExtensionType.get_type_list():
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
ocr_list.append(ocr_result)
for root, _, files in os.walk(unzip_path):
for filename in files:
invoice_file_path = Path(root) / Path(filename)
# Identify files that match the type
if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]:
ocr_result = await self._ocr(str(invoice_file_path))
ocr_list.append(ocr_result)
return ocr_list
else:
# Single file: run OCR on it directly
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(file_path), cls=True)
ocr_result = await self._ocr(file_path)
return [ocr_result]
@ -163,7 +127,7 @@ class GenerateTable(Action):
super().__init__(name, *args, **kwargs)
self.language = language
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> Dict[str, str]:
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
"""Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file.
Args:

View file

@ -7,10 +7,8 @@
@File : invoice_ocr_assistant.py
"""
import os
from pathlib import Path
import pandas as pd
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
from metagpt.roles import Role
@ -33,7 +31,7 @@ class InvoiceOCRAssistant(Role):
def __init__(
self,
name: str = "Stitch",
profile: str = "Invoice Ocr Assistant",
profile: str = "Invoice OCR Assistant",
goal: str = "OCR identifies invoice files and generates invoice main information table",
constraints: str = "",
language: str = "ch",
@ -67,11 +65,11 @@ class InvoiceOCRAssistant(Role):
if isinstance(todo, InvoiceOCR):
self.origin_query = msg.content
file_path = msg.instruct_content.get("file_path")
self.filename = os.path.basename(file_path)
self.filename = file_path.name
if not file_path:
raise Exception("Invoice file not uploaded")
resp = await todo.run(Path(file_path), self.filename)
resp = await todo.run(file_path)
if len(resp) == 1:
# Single file support for questioning based on OCR recognition results
self._init_actions([GenerateTable, ReplyQuestion])

4
requirements-ocr.txt Normal file
View file

@ -0,0 +1,4 @@
paddlepaddle==2.4.2
paddleocr>=2.0.1
tabulate==0.9.0
-r requirements.txt

View file

@ -1,5 +1,3 @@
paddlepaddle==2.4.2
paddleocr>=2.0.1
aiohttp==3.8.4
#azure_storage==0.37.0
channels==4.0.0
@ -44,5 +42,4 @@ pytest-mock==3.11.1
open-interpreter==0.1.4; python_version>"3.9"
ta==0.10.2
semantic-kernel==0.3.10.dev0
tabulate==0.9.0

View file

@ -8,10 +8,10 @@
"""
import os
from pathlib import Path
from typing import List
import pytest
from pathlib import Path
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
@ -48,7 +48,7 @@ async def test_invoice_ocr(invoice_path: str):
),
]
)
async def test_generate_table(invoice_path: str, expected_result: List[dict]):
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)

View file

@ -7,11 +7,10 @@
@File : test_invoice_ocr_assistant.py
"""
import os
import pandas as pd
from typing import List
from pathlib import Path
import pytest
import pandas as pd
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
@ -23,8 +22,8 @@ from metagpt.schema import Message
[
(
"Invoicing date",
"../../data/invoices/invoice-1.pdf",
"../../../data/invoice_table/invoice-1.xlsx",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
[
{
"收款人": "小明",
@ -36,8 +35,8 @@ from metagpt.schema import Message
),
(
"Invoicing date",
"../../data/invoices/invoice-2.png",
"../../../data/invoice_table/invoice-2.xlsx",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
[
{
"收款人": "铁头",
@ -49,8 +48,8 @@ from metagpt.schema import Message
),
(
"Invoicing date",
"../../data/invoices/invoice-3.jpg",
"../../../data/invoice_table/invoice-3.xlsx",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
[
{
"收款人": "夏天",
@ -62,8 +61,8 @@ from metagpt.schema import Message
),
(
"Invoicing date",
"../../data/invoices/invoice-4.zip",
"../../../data/invoice_table/invoice-4.xlsx",
Path("../../data/invoices/invoice-4.zip"),
Path("../../../data/invoice_table/invoice-4.xlsx"),
[
{
"收款人": "小明",
@ -89,17 +88,17 @@ from metagpt.schema import Message
)
async def test_invoice_ocr_assistant(
query: str,
invoice_path: str,
invoice_table_path: str,
expected_result: List[dict]
invoice_path: Path,
invoice_table_path: Path,
expected_result: list[dict]
):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
invoice_path = Path.cwd() / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(
content=query,
instruct_content={"file_path": invoice_path}
))
invoice_table_path = os.path.abspath(os.path.join(os.getcwd(), invoice_table_path))
invoice_table_path = Path.cwd() / invoice_table_path
df = pd.read_excel(invoice_table_path)
dict_result = df.to_dict(orient='records')
assert dict_result == expected_result