diff --git a/examples/invoice_ocr.py b/examples/invoice_ocr.py index 734cb2b4d..11656ed52 100644 --- a/examples/invoice_ocr.py +++ b/examples/invoice_ocr.py @@ -8,7 +8,7 @@ """ import asyncio -import os +from pathlib import Path from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant from metagpt.schema import Message @@ -16,15 +16,13 @@ from metagpt.schema import Message async def main(): relative_paths = [ - "../tests/data/invoices/invoice-1.pdf", - "../tests/data/invoices/invoice-2.png", - "../tests/data/invoices/invoice-3.jpg", - "../tests/data/invoices/invoice-4.zip" + Path("../tests/data/invoices/invoice-1.pdf"), + Path("../tests/data/invoices/invoice-2.png"), + Path("../tests/data/invoices/invoice-3.jpg"), + Path("../tests/data/invoices/invoice-4.zip") ] - # Get the current working directory - current_directory = os.getcwd() # The absolute path of the file - absolute_file_paths = [os.path.abspath(os.path.join(current_directory, path)) for path in relative_paths] + absolute_file_paths = [Path.cwd() / path for path in relative_paths] for path in absolute_file_paths: role = InvoiceOCRAssistant() diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index 172a2b981..2532543d9 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -10,12 +10,10 @@ import os import zipfile -from enum import Enum from pathlib import Path +from datetime import datetime import pandas as pd - -from typing import Dict, List from paddleocr import PaddleOCR from metagpt.actions import Action @@ -26,36 +24,6 @@ from metagpt.utils.common import OutputParser from metagpt.utils.file import File -class FileExtensionType(Enum): - """Enum representing file extensions and their associated types. - Each enum member consists of a tuple containing the file extension and its associated type. - - """ - - ZIP = (".zip", "zip") - PDF = (".pdf", "pdf") - PNG = (".png", "png") - JPG = (".jpg", "jpg") - - @classmethod - def get_extension_list(cls) -> List[str]: - """Get a list of file extensions. - - Returns: - A list of file extensions as strings. - """ - return [ext.value[0] for ext in cls] - - @classmethod - def get_type_list(cls) -> List[str]: - """Get a list of file types. - - Returns: - A list of file types as strings. - """ - return [ext.value[1] for ext in cls] - - class InvoiceOCR(Action): """Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files. @@ -69,11 +37,11 @@ class InvoiceOCR(Action): super().__init__(name, *args, **kwargs) @staticmethod - async def _check_file_type(filename: str) -> str: + async def _check_file_type(file_path: Path) -> str: """Check the file type of the given filename. Args: - filename: The name of the file. + file_path: The path of the file. Returns: The file type based on FileExtensionType enum. @@ -81,16 +49,11 @@ class InvoiceOCR(Action): Raises: Exception: If the file format is not zip, pdf, png, or jpg. """ - file_ext = None - for ext in FileExtensionType: - if filename.endswith(ext.value[0]): - file_ext = ext.value[1] - break - - if not file_ext: + ext = file_path.suffix + if ext not in [".zip", ".pdf", ".png", ".jpg"]: raise Exception("The invoice format is not zip, pdf, png, or jpg") - return file_ext + return ext @staticmethod async def _unzip(file_path: Path) -> Path: @@ -102,51 +65,52 @@ class InvoiceOCR(Action): Returns: The path to the unzipped directory. """ - file_directory = file_path.parent - with zipfile.ZipFile(file_path, 'r') as zip_ref: + file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S") + with zipfile.ZipFile(file_path, "r") as zip_ref: for zip_info in zip_ref.infolist(): # Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code - relative_name = zip_info.filename.encode('cp437').decode('gbk') - unzip_dir, name = relative_name.split("/") - if name: + relative_name = Path(zip_info.filename.encode("cp437").decode("gbk")) + if relative_name.suffix: + name = relative_name.name full_filename = file_directory / relative_name await File.write(full_filename.parent, name, zip_ref.read(zip_info.filename)) - unzip_path = file_directory / unzip_dir - logger.info(f"unzip_path: {unzip_path}") - return unzip_path + logger.info(f"unzip_path: {file_directory}") + return file_directory - async def run(self, file_path: Path, filename: str, *args, **kwargs) -> list: + @staticmethod + async def _ocr(invoice_file_path: Path): + ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) + ocr_result = ocr.ocr(str(invoice_file_path), cls=True) + return ocr_result + + async def run(self, file_path: Path, *args, **kwargs) -> list: """Execute the action to identify invoice files through OCR. Args: file_path: The path to the input file. - filename: The name of the input file. Returns: A list of OCR results. """ - file_ext = await self._check_file_type(filename) + file_ext = await self._check_file_type(file_path) - if file_ext == FileExtensionType.ZIP.value[1]: + if file_ext == ".zip": # OCR recognizes zip batch files unzip_path = await self._unzip(file_path) - file_list = os.listdir(unzip_path) ocr_list = [] - - for filename in file_list: - invoice_file_path = unzip_path / filename - # Identify files that match the type - if filename.split(".")[-1] in FileExtensionType.get_type_list(): - ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) - ocr_result = ocr.ocr(str(invoice_file_path), cls=True) - ocr_list.append(ocr_result) + for root, _, files in os.walk(unzip_path): + for filename in files: + invoice_file_path = Path(root) / Path(filename) + # Identify files that match the type + if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]: + ocr_result = await self._ocr(str(invoice_file_path)) + ocr_list.append(ocr_result) return ocr_list else: # OCR identifies single file - ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) - ocr_result = ocr.ocr(str(file_path), cls=True) + ocr_result = await self._ocr(file_path) return [ocr_result] @@ -163,7 +127,7 @@ class GenerateTable(Action): super().__init__(name, *args, **kwargs) self.language = language - async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> Dict[str, str]: + async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]: """Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file. Args: diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py index ede41dfc2..c307b20c0 100644 --- a/metagpt/roles/invoice_ocr_assistant.py +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -7,10 +7,8 @@ @File : invoice_ocr_assistant.py """ -import os -from pathlib import Path - import pandas as pd + from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS from metagpt.roles import Role @@ -33,7 +31,7 @@ class InvoiceOCRAssistant(Role): def __init__( self, name: str = "Stitch", - profile: str = "Invoice Ocr Assistant", + profile: str = "Invoice OCR Assistant", goal: str = "OCR identifies invoice files and generates invoice main information table", constraints: str = "", language: str = "ch", @@ -67,11 +65,11 @@ class InvoiceOCRAssistant(Role): if isinstance(todo, InvoiceOCR): self.origin_query = msg.content file_path = msg.instruct_content.get("file_path") - self.filename = os.path.basename(file_path) + self.filename = file_path.name if not file_path: raise Exception("Invoice file not uploaded") - resp = await todo.run(Path(file_path), self.filename) + resp = await todo.run(file_path) if len(resp) == 1: # Single file support for questioning based on OCR recognition results self._init_actions([GenerateTable, ReplyQuestion]) diff --git a/requirements-ocr.txt b/requirements-ocr.txt new file mode 100644 index 000000000..cf6103afc --- /dev/null +++ b/requirements-ocr.txt @@ -0,0 +1,4 @@ +paddlepaddle==2.4.2 +paddleocr>=2.0.1 +tabulate==0.9.0 +-r requirements.txt diff --git a/requirements.txt b/requirements.txt index cf258cd06..de861ded9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -paddlepaddle==2.4.2 -paddleocr>=2.0.1 aiohttp==3.8.4 #azure_storage==0.37.0 channels==4.0.0 @@ -44,5 +42,4 @@ pytest-mock==3.11.1 open-interpreter==0.1.4; python_version>"3.9" ta==0.10.2 semantic-kernel==0.3.10.dev0 -tabulate==0.9.0 diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index 0433a2a89..a15166f7c 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -8,10 +8,10 @@ """ import os -from pathlib import Path from typing import List import pytest +from pathlib import Path from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion @@ -48,7 +48,7 @@ async def test_invoice_ocr(invoice_path: str): ), ] ) -async def test_generate_table(invoice_path: str, expected_result: List[dict]): +async def test_generate_table(invoice_path: str, expected_result: list[dict]): invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) filename = os.path.basename(invoice_path) ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index 32493b831..75097e73c 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -7,11 +7,10 @@ @File : test_invoice_ocr_assistant.py """ -import os -import pandas as pd -from typing import List +from pathlib import Path import pytest +import pandas as pd from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant from metagpt.schema import Message @@ -23,8 +22,8 @@ from metagpt.schema import Message [ ( "Invoicing date", - "../../data/invoices/invoice-1.pdf", - "../../../data/invoice_table/invoice-1.xlsx", + Path("../../data/invoices/invoice-1.pdf"), + Path("../../../data/invoice_table/invoice-1.xlsx"), [ { "收款人": "小明", @@ -36,8 +35,8 @@ from metagpt.schema import Message ), ( "Invoicing date", - "../../data/invoices/invoice-2.png", - "../../../data/invoice_table/invoice-2.xlsx", + Path("../../data/invoices/invoice-2.png"), + Path("../../../data/invoice_table/invoice-2.xlsx"), [ { "收款人": "铁头", @@ -49,8 +48,8 @@ from metagpt.schema import Message ), ( "Invoicing date", - "../../data/invoices/invoice-3.jpg", - "../../../data/invoice_table/invoice-3.xlsx", + Path("../../data/invoices/invoice-3.jpg"), + Path("../../../data/invoice_table/invoice-3.xlsx"), [ { "收款人": "夏天", @@ -62,8 +61,8 @@ from metagpt.schema import Message ), ( "Invoicing date", - "../../data/invoices/invoice-4.zip", - "../../../data/invoice_table/invoice-4.xlsx", + Path("../../data/invoices/invoice-4.zip"), + Path("../../../data/invoice_table/invoice-4.xlsx"), [ { "收款人": "小明", @@ -89,17 +88,17 @@ from metagpt.schema import Message ) async def test_invoice_ocr_assistant( query: str, - invoice_path: str, - invoice_table_path: str, - expected_result: List[dict] + invoice_path: Path, + invoice_table_path: Path, + expected_result: list[dict] ): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + invoice_path = Path.cwd() / invoice_path role = InvoiceOCRAssistant() await role.run(Message( content=query, instruct_content={"file_path": invoice_path} )) - invoice_table_path = os.path.abspath(os.path.join(os.getcwd(), invoice_table_path)) + invoice_table_path = Path.cwd() / invoice_table_path df = pd.read_excel(invoice_table_path) dict_result = df.to_dict(orient='records') assert dict_result == expected_result