diff --git a/Dockerfile b/Dockerfile index 8ab180e28..c6e22989b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ FROM nikolaik/python-nodejs:python3.9-nodejs20-slim # Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size RUN apt update &&\ - apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ + apt install -y libgomp1 git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ apt clean && rm -rf /var/lib/apt/lists/* # Install Mermaid CLI globally diff --git a/examples/invoice_ocr.py b/examples/invoice_ocr.py new file mode 100644 index 000000000..11656ed52 --- /dev/null +++ b/examples/invoice_ocr.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 21:40:57 +@Author : Stitch-z +@File : invoice_ocr.py +""" + +import asyncio +from pathlib import Path + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant +from metagpt.schema import Message + + +async def main(): + relative_paths = [ + Path("../tests/data/invoices/invoice-1.pdf"), + Path("../tests/data/invoices/invoice-2.png"), + Path("../tests/data/invoices/invoice-3.jpg"), + Path("../tests/data/invoices/invoice-4.zip") + ] + # The absolute path of the file + absolute_file_paths = [Path.cwd() / path for path in relative_paths] + + for path in absolute_file_paths: + role = InvoiceOCRAssistant() + await role.run(Message( + content="Invoicing date", + instruct_content={"file_path": path} + )) + + +if __name__ == '__main__': + asyncio.run(main()) + diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py new file mode 100644 index 000000000..b37aa6885 --- /dev/null +++ b/metagpt/actions/invoice_ocr.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 18:10:20 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Actions of the invoice ocr assistant. +""" + +import os +import zipfile +from pathlib import Path +from datetime import datetime + +import pandas as pd +from paddleocr import PaddleOCR + +from metagpt.actions import Action +from metagpt.const import INVOICE_OCR_TABLE_PATH +from metagpt.logs import logger +from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT +from metagpt.utils.common import OutputParser +from metagpt.utils.file import File + + +class InvoiceOCR(Action): + """Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language for OCR output. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + + @staticmethod + async def _check_file_type(file_path: Path) -> str: + """Check the file type of the given filename. + + Args: + file_path: The path of the file. + + Returns: + The file type based on FileExtensionType enum. + + Raises: + Exception: If the file format is not zip, pdf, png, or jpg. + """ + ext = file_path.suffix + if ext not in [".zip", ".pdf", ".png", ".jpg"]: + raise Exception("The invoice format is not zip, pdf, png, or jpg") + + return ext + + @staticmethod + async def _unzip(file_path: Path) -> Path: + """Unzip a file and return the path to the unzipped directory. + + Args: + file_path: The path to the zip file. + + Returns: + The path to the unzipped directory. + """ + file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S") + with zipfile.ZipFile(file_path, "r") as zip_ref: + for zip_info in zip_ref.infolist(): + # Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code + relative_name = Path(zip_info.filename.encode("cp437").decode("gbk")) + if relative_name.suffix: + full_filename = file_directory / relative_name + await File.write(full_filename.parent, relative_name.name, zip_ref.read(zip_info.filename)) + + logger.info(f"unzip_path: {file_directory}") + return file_directory + + @staticmethod + async def _ocr(invoice_file_path: Path): + ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) + ocr_result = ocr.ocr(str(invoice_file_path), cls=True) + return ocr_result + + async def run(self, file_path: Path, *args, **kwargs) -> list: + """Execute the action to identify invoice files through OCR. + + Args: + file_path: The path to the input file. + + Returns: + A list of OCR results. + """ + file_ext = await self._check_file_type(file_path) + + if file_ext == ".zip": + # OCR recognizes zip batch files + unzip_path = await self._unzip(file_path) + ocr_list = [] + for root, _, files in os.walk(unzip_path): + for filename in files: + invoice_file_path = Path(root) / Path(filename) + # Identify files that match the type + if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]: + ocr_result = await self._ocr(str(invoice_file_path)) + ocr_list.append(ocr_result) + return ocr_list + + else: + # OCR identifies single file + ocr_result = await self._ocr(file_path) + return [ocr_result] + + +class GenerateTable(Action): + """Action class for generating tables from OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for the generated table. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", language: str = "ch", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.language = language + + async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]: + """Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file. + + Args: + ocr_results: A list of OCR results obtained from invoice processing. + filename: The name of the output Excel file. + + Returns: + A dictionary containing the invoice information. + + """ + table_data = [] + pathname = INVOICE_OCR_TABLE_PATH + pathname.mkdir(parents=True, exist_ok=True) + + for ocr_result in ocr_results: + # Extract invoice OCR main information + prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language) + ocr_info = await self._aask(prompt=prompt) + invoice_data = OutputParser.extract_struct(ocr_info, dict) + if invoice_data: + table_data.append(invoice_data) + + # Generate Excel file + filename = f"{filename.split('.')[0]}.xlsx" + full_filename = f"{pathname}/{filename}" + df = pd.DataFrame(table_data) + df.to_excel(full_filename, index=False) + return table_data + + +class ReplyQuestion(Action): + """Action class for generating replies to questions based on OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for generating the reply. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", language: str = "ch", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.language = language + + async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str: + """Reply to questions based on ocr results. + + Args: + query: The question for which a reply is generated. + ocr_result: A list of OCR results. + + Returns: + A reply result of string type. + """ + prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language) + resp = await self._aask(prompt=prompt) + return resp + diff --git a/metagpt/const.py b/metagpt/const.py index b8b08628e..7f3f87dfa 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -36,6 +36,7 @@ YAPI_URL = "http://yapi.deepwisdomai.com/" TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" +INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/prompts/invoice_ocr.py b/metagpt/prompts/invoice_ocr.py new file mode 100644 index 000000000..52f628a5b --- /dev/null +++ b/metagpt/prompts/invoice_ocr.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 16:30:25 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Prompts of the invoice ocr assistant. +""" + +COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice." + +EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """ +Please extract the payee, city, total cost, and invoicing date of the invoice. + +The OCR data of the invoice are as follows: +{ocr_result} + +Mandatory restrictions are returned according to the following requirements: +1. The total cost refers to the total price and tax. Do not include `¥`. +2. The city must be the recipient's city. +2. The returned JSON dictionary must be returned in {language} +3. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}. +""" + +REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """ +Please answer the question: {query} + +The OCR data of the invoice are as follows: +{ocr_result} + +Mandatory restrictions are returned according to the following requirements: +1. Answer in {language} language. +2. Enforce restrictions on not returning OCR data sent to you. +3. Return with markdown syntax layout. +""" + +INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice." + diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py new file mode 100644 index 000000000..c307b20c0 --- /dev/null +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 14:10:05 +@Author : Stitch-z +@File : invoice_ocr_assistant.py +""" + +import pandas as pd + +from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion +from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS +from metagpt.roles import Role +from metagpt.schema import Message + + +class InvoiceOCRAssistant(Role): + """Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files, + generate a table for the payee, city, total amount, and invoicing date of the invoice, + and ask questions for a single file based on the OCR recognition results of the invoice. + + Args: + name: The name of the role. + profile: The role profile description. + goal: The goal of the role. + constraints: Constraints or requirements for the role. + language: The language in which the invoice table will be generated. + """ + + def __init__( + self, + name: str = "Stitch", + profile: str = "Invoice OCR Assistant", + goal: str = "OCR identifies invoice files and generates invoice main information table", + constraints: str = "", + language: str = "ch", + ): + super().__init__(name, profile, goal, constraints) + self._init_actions([InvoiceOCR]) + self.language = language + self.filename = "" + self.origin_query = "" + self.orc_data = None + + async def _think(self) -> None: + """Determine the next action to be taken by the role.""" + if self._rc.todo is None: + self._set_state(0) + return + + if self._rc.state + 1 < len(self._states): + self._set_state(self._rc.state + 1) + else: + self._rc.todo = None + + async def _act(self) -> Message: + """Perform an action as determined by the role. + + Returns: + A message containing the result of the action. + """ + msg = self._rc.memory.get(k=1)[0] + todo = self._rc.todo + if isinstance(todo, InvoiceOCR): + self.origin_query = msg.content + file_path = msg.instruct_content.get("file_path") + self.filename = file_path.name + if not file_path: + raise Exception("Invoice file not uploaded") + + resp = await todo.run(file_path) + if len(resp) == 1: + # Single file support for questioning based on OCR recognition results + self._init_actions([GenerateTable, ReplyQuestion]) + self.orc_data = resp[0] + else: + self._init_actions([GenerateTable]) + + self._rc.todo = None + content = INVOICE_OCR_SUCCESS + elif isinstance(todo, GenerateTable): + ocr_results = msg.instruct_content + resp = await todo.run(ocr_results, self.filename) + + # Convert list to Markdown format string + df = pd.DataFrame(resp) + markdown_table = df.to_markdown(index=False) + content = f"{markdown_table}\n\n\n" + else: + resp = await todo.run(self.origin_query, self.orc_data) + content = resp + + msg = Message(content=content, instruct_content=resp) + self._rc.memory.add(msg) + return msg + + async def _react(self) -> Message: + """Execute the invoice ocr assistant's think and actions. + + Returns: + A message containing the final result of the assistant's actions. + """ + while True: + await self._think() + if self._rc.todo is None: + break + msg = await self._act() + return msg + diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 59d179808..f09666beb 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -195,7 +195,8 @@ class OutputParser: except (ValueError, SyntaxError) as e: raise Exception(f"Error while extracting and parsing the {data_type}: {e}") else: - raise Exception(f"No {data_type} found in the text.") + logger.error(f"No {data_type} found in the text.") + return [] if data_type is list else {} class CodeParser: diff --git a/requirements-ocr.txt b/requirements-ocr.txt new file mode 100644 index 000000000..cf6103afc --- /dev/null +++ b/requirements-ocr.txt @@ -0,0 +1,4 @@ +paddlepaddle==2.4.2 +paddleocr>=2.0.1 +tabulate==0.9.0 +-r requirements.txt diff --git a/tests/data/invoices/invoice-1.pdf b/tests/data/invoices/invoice-1.pdf new file mode 100644 index 000000000..7f53133ef Binary files /dev/null and b/tests/data/invoices/invoice-1.pdf differ diff --git a/tests/data/invoices/invoice-2.png b/tests/data/invoices/invoice-2.png new file mode 100644 index 000000000..412ec3a24 Binary files /dev/null and b/tests/data/invoices/invoice-2.png differ diff --git a/tests/data/invoices/invoice-3.jpg b/tests/data/invoices/invoice-3.jpg new file mode 100644 index 000000000..8fb74dd30 Binary files /dev/null and b/tests/data/invoices/invoice-3.jpg differ diff --git a/tests/data/invoices/invoice-4.zip b/tests/data/invoices/invoice-4.zip new file mode 100644 index 000000000..c6be51a2b Binary files /dev/null and b/tests/data/invoices/invoice-4.zip differ diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py new file mode 100644 index 000000000..a15166f7c --- /dev/null +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/10/09 18:40:34 +@Author : Stitch-z +@File : test_invoice_ocr.py +""" + +import os +from typing import List + +import pytest +from pathlib import Path + +from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "invoice_path", + [ + "../../data/invoices/invoice-3.jpg", + "../../data/invoices/invoice-4.zip", + ] +) +async def test_invoice_ocr(invoice_path: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + assert isinstance(resp, list) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "expected_result"), + [ + ( + "../../data/invoices/invoice-1.pdf", + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": "412.00", + "开票日期": "2023年02月03日" + } + ] + ), + ] +) +async def test_generate_table(invoice_path: str, expected_result: list[dict]): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) + assert table_data == expected_result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "query", "expected_result"), + [ + ("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日") + ] +) +async def test_reply_question(invoice_path: str, query: dict, expected_result: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) + assert expected_result in result + diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py new file mode 100644 index 000000000..75097e73c --- /dev/null +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 23:11:27 +@Author : Stitch-z +@File : test_invoice_ocr_assistant.py +""" + +from pathlib import Path + +import pytest +import pandas as pd + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant +from metagpt.schema import Message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "invoice_path", "invoice_table_path", "expected_result"), + [ + ( + "Invoicing date", + Path("../../data/invoices/invoice-1.pdf"), + Path("../../../data/invoice_table/invoice-1.xlsx"), + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": 412.00, + "开票日期": "2023年02月03日" + } + ] + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-2.png"), + Path("../../../data/invoice_table/invoice-2.xlsx"), + [ + { + "收款人": "铁头", + "城市": "广州市", + "总费用/元": 898.00, + "开票日期": "2023年03月17日" + } + ] + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-3.jpg"), + Path("../../../data/invoice_table/invoice-3.xlsx"), + [ + { + "收款人": "夏天", + "城市": "福州市", + "总费用/元": 2462.00, + "开票日期": "2023年08月26日" + } + ] + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-4.zip"), + Path("../../../data/invoice_table/invoice-4.xlsx"), + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": 412.00, + "开票日期": "2023年02月03日" + }, + { + "收款人": "铁头", + "城市": "广州市", + "总费用/元": 898.00, + "开票日期": "2023年03月17日" + }, + { + "收款人": "夏天", + "城市": "福州市", + "总费用/元": 2462.00, + "开票日期": "2023年08月26日" + } + ] + ), + ] +) +async def test_invoice_ocr_assistant( + query: str, + invoice_path: Path, + invoice_table_path: Path, + expected_result: list[dict] +): + invoice_path = Path.cwd() / invoice_path + role = InvoiceOCRAssistant() + await role.run(Message( + content=query, + instruct_content={"file_path": invoice_path} + )) + invoice_table_path = Path.cwd() / invoice_table_path + df = pd.read_excel(invoice_table_path) + dict_result = df.to_dict(orient='records') + assert dict_result == expected_result + diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index 2b706efc4..4e362f9f7 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -95,7 +95,7 @@ def test_parse_data(): """xxx xx""", list, None, - Exception, + [], ), ( """xxx [1, 2, []xx""",