mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 02:23:52 +02:00
update: resolve comments and support any directory zip file
This commit is contained in:
parent
d8afe9ea86
commit
5e27ee4bad
7 changed files with 62 additions and 102 deletions
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -16,15 +16,13 @@ from metagpt.schema import Message
|
|||
|
||||
async def main():
|
||||
relative_paths = [
|
||||
"../tests/data/invoices/invoice-1.pdf",
|
||||
"../tests/data/invoices/invoice-2.png",
|
||||
"../tests/data/invoices/invoice-3.jpg",
|
||||
"../tests/data/invoices/invoice-4.zip"
|
||||
Path("../tests/data/invoices/invoice-1.pdf"),
|
||||
Path("../tests/data/invoices/invoice-2.png"),
|
||||
Path("../tests/data/invoices/invoice-3.jpg"),
|
||||
Path("../tests/data/invoices/invoice-4.zip")
|
||||
]
|
||||
# Get the current working directory
|
||||
current_directory = os.getcwd()
|
||||
# The absolute path of the file
|
||||
absolute_file_paths = [os.path.abspath(os.path.join(current_directory, path)) for path in relative_paths]
|
||||
absolute_file_paths = [Path.cwd() / path for path in relative_paths]
|
||||
|
||||
for path in absolute_file_paths:
|
||||
role = InvoiceOCRAssistant()
|
||||
|
|
|
|||
|
|
@ -10,12 +10,10 @@
|
|||
|
||||
import os
|
||||
import zipfile
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from typing import Dict, List
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
from metagpt.actions import Action
|
||||
|
|
@ -26,36 +24,6 @@ from metagpt.utils.common import OutputParser
|
|||
from metagpt.utils.file import File
|
||||
|
||||
|
||||
class FileExtensionType(Enum):
|
||||
"""Enum representing file extensions and their associated types.
|
||||
Each enum member consists of a tuple containing the file extension and its associated type.
|
||||
|
||||
"""
|
||||
|
||||
ZIP = (".zip", "zip")
|
||||
PDF = (".pdf", "pdf")
|
||||
PNG = (".png", "png")
|
||||
JPG = (".jpg", "jpg")
|
||||
|
||||
@classmethod
|
||||
def get_extension_list(cls) -> List[str]:
|
||||
"""Get a list of file extensions.
|
||||
|
||||
Returns:
|
||||
A list of file extensions as strings.
|
||||
"""
|
||||
return [ext.value[0] for ext in cls]
|
||||
|
||||
@classmethod
|
||||
def get_type_list(cls) -> List[str]:
|
||||
"""Get a list of file types.
|
||||
|
||||
Returns:
|
||||
A list of file types as strings.
|
||||
"""
|
||||
return [ext.value[1] for ext in cls]
|
||||
|
||||
|
||||
class InvoiceOCR(Action):
|
||||
"""Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files.
|
||||
|
||||
|
|
@ -69,11 +37,11 @@ class InvoiceOCR(Action):
|
|||
super().__init__(name, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def _check_file_type(filename: str) -> str:
|
||||
async def _check_file_type(file_path: Path) -> str:
|
||||
"""Check the file type of the given filename.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
file_path: The path of the file.
|
||||
|
||||
Returns:
|
||||
The file type based on FileExtensionType enum.
|
||||
|
|
@ -81,16 +49,11 @@ class InvoiceOCR(Action):
|
|||
Raises:
|
||||
Exception: If the file format is not zip, pdf, png, or jpg.
|
||||
"""
|
||||
file_ext = None
|
||||
for ext in FileExtensionType:
|
||||
if filename.endswith(ext.value[0]):
|
||||
file_ext = ext.value[1]
|
||||
break
|
||||
|
||||
if not file_ext:
|
||||
ext = file_path.suffix
|
||||
if ext not in [".zip", ".pdf", ".png", ".jpg"]:
|
||||
raise Exception("The invoice format is not zip, pdf, png, or jpg")
|
||||
|
||||
return file_ext
|
||||
return ext
|
||||
|
||||
@staticmethod
|
||||
async def _unzip(file_path: Path) -> Path:
|
||||
|
|
@ -102,51 +65,52 @@ class InvoiceOCR(Action):
|
|||
Returns:
|
||||
The path to the unzipped directory.
|
||||
"""
|
||||
file_directory = file_path.parent
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
||||
file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
||||
for zip_info in zip_ref.infolist():
|
||||
# Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code
|
||||
relative_name = zip_info.filename.encode('cp437').decode('gbk')
|
||||
unzip_dir, name = relative_name.split("/")
|
||||
if name:
|
||||
relative_name = Path(zip_info.filename.encode("cp437").decode("gbk"))
|
||||
if relative_name.suffix:
|
||||
name = relative_name.name
|
||||
full_filename = file_directory / relative_name
|
||||
await File.write(full_filename.parent, name, zip_ref.read(zip_info.filename))
|
||||
|
||||
unzip_path = file_directory / unzip_dir
|
||||
logger.info(f"unzip_path: {unzip_path}")
|
||||
return unzip_path
|
||||
logger.info(f"unzip_path: {file_directory}")
|
||||
return file_directory
|
||||
|
||||
async def run(self, file_path: Path, filename: str, *args, **kwargs) -> list:
|
||||
@staticmethod
|
||||
async def _ocr(invoice_file_path: Path):
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
|
||||
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
|
||||
return ocr_result
|
||||
|
||||
async def run(self, file_path: Path, *args, **kwargs) -> list:
|
||||
"""Execute the action to identify invoice files through OCR.
|
||||
|
||||
Args:
|
||||
file_path: The path to the input file.
|
||||
filename: The name of the input file.
|
||||
|
||||
Returns:
|
||||
A list of OCR results.
|
||||
"""
|
||||
file_ext = await self._check_file_type(filename)
|
||||
file_ext = await self._check_file_type(file_path)
|
||||
|
||||
if file_ext == FileExtensionType.ZIP.value[1]:
|
||||
if file_ext == ".zip":
|
||||
# OCR recognizes zip batch files
|
||||
unzip_path = await self._unzip(file_path)
|
||||
file_list = os.listdir(unzip_path)
|
||||
ocr_list = []
|
||||
|
||||
for filename in file_list:
|
||||
invoice_file_path = unzip_path / filename
|
||||
# Identify files that match the type
|
||||
if filename.split(".")[-1] in FileExtensionType.get_type_list():
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
|
||||
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
|
||||
ocr_list.append(ocr_result)
|
||||
for root, _, files in os.walk(unzip_path):
|
||||
for filename in files:
|
||||
invoice_file_path = Path(root) / Path(filename)
|
||||
# Identify files that match the type
|
||||
if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]:
|
||||
ocr_result = await self._ocr(str(invoice_file_path))
|
||||
ocr_list.append(ocr_result)
|
||||
return ocr_list
|
||||
|
||||
else:
|
||||
# OCR identifies single file
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
|
||||
ocr_result = ocr.ocr(str(file_path), cls=True)
|
||||
ocr_result = await self._ocr(file_path)
|
||||
return [ocr_result]
|
||||
|
||||
|
||||
|
|
@ -163,7 +127,7 @@ class GenerateTable(Action):
|
|||
super().__init__(name, *args, **kwargs)
|
||||
self.language = language
|
||||
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> Dict[str, str]:
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
|
||||
"""Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -7,10 +7,8 @@
|
|||
@File : invoice_ocr_assistant.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
||||
from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -33,7 +31,7 @@ class InvoiceOCRAssistant(Role):
|
|||
def __init__(
|
||||
self,
|
||||
name: str = "Stitch",
|
||||
profile: str = "Invoice Ocr Assistant",
|
||||
profile: str = "Invoice OCR Assistant",
|
||||
goal: str = "OCR identifies invoice files and generates invoice main information table",
|
||||
constraints: str = "",
|
||||
language: str = "ch",
|
||||
|
|
@ -67,11 +65,11 @@ class InvoiceOCRAssistant(Role):
|
|||
if isinstance(todo, InvoiceOCR):
|
||||
self.origin_query = msg.content
|
||||
file_path = msg.instruct_content.get("file_path")
|
||||
self.filename = os.path.basename(file_path)
|
||||
self.filename = file_path.name
|
||||
if not file_path:
|
||||
raise Exception("Invoice file not uploaded")
|
||||
|
||||
resp = await todo.run(Path(file_path), self.filename)
|
||||
resp = await todo.run(file_path)
|
||||
if len(resp) == 1:
|
||||
# Single file support for questioning based on OCR recognition results
|
||||
self._init_actions([GenerateTable, ReplyQuestion])
|
||||
|
|
|
|||
4
requirements-ocr.txt
Normal file
4
requirements-ocr.txt
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
paddlepaddle==2.4.2
|
||||
paddleocr>=2.0.1
|
||||
tabulate==0.9.0
|
||||
-r requirements.txt
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
paddlepaddle==2.4.2
|
||||
paddleocr>=2.0.1
|
||||
aiohttp==3.8.4
|
||||
#azure_storage==0.37.0
|
||||
channels==4.0.0
|
||||
|
|
@ -44,5 +42,4 @@ pytest-mock==3.11.1
|
|||
open-interpreter==0.1.4; python_version>"3.9"
|
||||
ta==0.10.2
|
||||
semantic-kernel==0.3.10.dev0
|
||||
tabulate==0.9.0
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ async def test_invoice_ocr(invoice_path: str):
|
|||
),
|
||||
]
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: List[dict]):
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,10 @@
|
|||
@File : test_invoice_ocr_assistant.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -23,8 +22,8 @@ from metagpt.schema import Message
|
|||
[
|
||||
(
|
||||
"Invoicing date",
|
||||
"../../data/invoices/invoice-1.pdf",
|
||||
"../../../data/invoice_table/invoice-1.xlsx",
|
||||
Path("../../data/invoices/invoice-1.pdf"),
|
||||
Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
|
|
@ -36,8 +35,8 @@ from metagpt.schema import Message
|
|||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
"../../data/invoices/invoice-2.png",
|
||||
"../../../data/invoice_table/invoice-2.xlsx",
|
||||
Path("../../data/invoices/invoice-2.png"),
|
||||
Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "铁头",
|
||||
|
|
@ -49,8 +48,8 @@ from metagpt.schema import Message
|
|||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
"../../data/invoices/invoice-3.jpg",
|
||||
"../../../data/invoice_table/invoice-3.xlsx",
|
||||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "夏天",
|
||||
|
|
@ -62,8 +61,8 @@ from metagpt.schema import Message
|
|||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
"../../data/invoices/invoice-4.zip",
|
||||
"../../../data/invoice_table/invoice-4.xlsx",
|
||||
Path("../../data/invoices/invoice-4.zip"),
|
||||
Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
|
|
@ -89,17 +88,17 @@ from metagpt.schema import Message
|
|||
)
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str,
|
||||
invoice_path: str,
|
||||
invoice_table_path: str,
|
||||
expected_result: List[dict]
|
||||
invoice_path: Path,
|
||||
invoice_table_path: Path,
|
||||
expected_result: list[dict]
|
||||
):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(
|
||||
content=query,
|
||||
instruct_content={"file_path": invoice_path}
|
||||
))
|
||||
invoice_table_path = os.path.abspath(os.path.join(os.getcwd(), invoice_table_path))
|
||||
invoice_table_path = Path.cwd() / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
dict_result = df.to_dict(orient='records')
|
||||
assert dict_result == expected_result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue