diff --git a/Dockerfile b/Dockerfile index 8ab180e28..c6e22989b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ FROM nikolaik/python-nodejs:python3.9-nodejs20-slim # Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size RUN apt update &&\ - apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ + apt install -y libgomp1 git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ apt clean && rm -rf /var/lib/apt/lists/* # Install Mermaid CLI globally diff --git a/examples/invoice_ocr.py b/examples/invoice_ocr.py new file mode 100644 index 000000000..e1f7a63fc --- /dev/null +++ b/examples/invoice_ocr.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 21:40:57 +@Author : Stitch-z +@File : tutorial_assistant.py +""" + +import asyncio +import os + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant +from metagpt.schema import Message + + +async def main(): + relative_paths = [ + "../tests/data/invoices/invoice-1.pdf", + "../tests/data/invoices/invoice-2.png", + "../tests/data/invoices/invoice-3.jpg", + "../tests/data/invoices/invoice-4.zip" + ] + # Get the current working directory + current_directory = os.getcwd() + # The absolute path of the file + absolute_file_paths = [os.path.abspath(os.path.join(current_directory, path)) for path in relative_paths] + + for path in absolute_file_paths: + role = InvoiceOCRAssistant() + await role.run(Message( + content="Invoicing date", + instruct_content={"file_path": path} + )) + + +if __name__ == '__main__': + asyncio.run(main()) + diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py new file mode 100644 index 000000000..172a2b981 --- /dev/null +++ b/metagpt/actions/invoice_ocr.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 18:10:20 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Actions of the invoice ocr assistant. +""" + +import os +import zipfile +from enum import Enum +from pathlib import Path + +import pandas as pd + +from typing import Dict, List +from paddleocr import PaddleOCR + +from metagpt.actions import Action +from metagpt.const import INVOICE_OCR_TABLE_PATH +from metagpt.logs import logger +from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT +from metagpt.utils.common import OutputParser +from metagpt.utils.file import File + + +class FileExtensionType(Enum): + """Enum representing file extensions and their associated types. + Each enum member consists of a tuple containing the file extension and its associated type. + + """ + + ZIP = (".zip", "zip") + PDF = (".pdf", "pdf") + PNG = (".png", "png") + JPG = (".jpg", "jpg") + + @classmethod + def get_extension_list(cls) -> List[str]: + """Get a list of file extensions. + + Returns: + A list of file extensions as strings. + """ + return [ext.value[0] for ext in cls] + + @classmethod + def get_type_list(cls) -> List[str]: + """Get a list of file types. + + Returns: + A list of file types as strings. + """ + return [ext.value[1] for ext in cls] + + +class InvoiceOCR(Action): + """Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language for OCR output. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + + @staticmethod + async def _check_file_type(filename: str) -> str: + """Check the file type of the given filename. + + Args: + filename: The name of the file. + + Returns: + The file type based on FileExtensionType enum. + + Raises: + Exception: If the file format is not zip, pdf, png, or jpg. + """ + file_ext = None + for ext in FileExtensionType: + if filename.endswith(ext.value[0]): + file_ext = ext.value[1] + break + + if not file_ext: + raise Exception("The invoice format is not zip, pdf, png, or jpg") + + return file_ext + + @staticmethod + async def _unzip(file_path: Path) -> Path: + """Unzip a file and return the path to the unzipped directory. + + Args: + file_path: The path to the zip file. + + Returns: + The path to the unzipped directory. + """ + file_directory = file_path.parent + with zipfile.ZipFile(file_path, 'r') as zip_ref: + for zip_info in zip_ref.infolist(): + # Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code + relative_name = zip_info.filename.encode('cp437').decode('gbk') + unzip_dir, name = relative_name.split("/") + if name: + full_filename = file_directory / relative_name + await File.write(full_filename.parent, name, zip_ref.read(zip_info.filename)) + + unzip_path = file_directory / unzip_dir + logger.info(f"unzip_path: {unzip_path}") + return unzip_path + + async def run(self, file_path: Path, filename: str, *args, **kwargs) -> list: + """Execute the action to identify invoice files through OCR. + + Args: + file_path: The path to the input file. + filename: The name of the input file. + + Returns: + A list of OCR results. + """ + file_ext = await self._check_file_type(filename) + + if file_ext == FileExtensionType.ZIP.value[1]: + # OCR recognizes zip batch files + unzip_path = await self._unzip(file_path) + file_list = os.listdir(unzip_path) + ocr_list = [] + + for filename in file_list: + invoice_file_path = unzip_path / filename + # Identify files that match the type + if filename.split(".")[-1] in FileExtensionType.get_type_list(): + ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) + ocr_result = ocr.ocr(str(invoice_file_path), cls=True) + ocr_list.append(ocr_result) + return ocr_list + + else: + # OCR identifies single file + ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) + ocr_result = ocr.ocr(str(file_path), cls=True) + return [ocr_result] + + +class GenerateTable(Action): + """Action class for generating tables from OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for the generated table. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", language: str = "ch", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.language = language + + async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> Dict[str, str]: + """Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file. + + Args: + ocr_results: A list of OCR results obtained from invoice processing. + filename: The name of the output Excel file. + + Returns: + A dictionary containing the invoice information. + + """ + table_data = [] + pathname = INVOICE_OCR_TABLE_PATH + pathname.mkdir(parents=True, exist_ok=True) + + for ocr_result in ocr_results: + # Extract invoice OCR main information + prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language) + ocr_info = await self._aask(prompt=prompt) + invoice_data = OutputParser.extract_struct(ocr_info, dict) + if invoice_data: + table_data.append(invoice_data) + + # Generate Excel file + filename = f"{filename.split('.')[0]}.xlsx" + full_filename = f"{pathname}/{filename}" + df = pd.DataFrame(table_data) + df.to_excel(full_filename, index=False) + return table_data + + +class ReplyQuestion(Action): + """Action class for generating replies to questions based on OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for generating the reply. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", language: str = "ch", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.language = language + + async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str: + """Reply to questions based on ocr results. + + Args: + query: The question for which a reply is generated. + ocr_result: A list of OCR results. + + Returns: + A reply result of string type. + """ + prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language) + resp = await self._aask(prompt=prompt) + return resp + diff --git a/metagpt/const.py b/metagpt/const.py index b8b08628e..7f3f87dfa 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -36,6 +36,7 @@ YAPI_URL = "http://yapi.deepwisdomai.com/" TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" +INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/prompts/invoice_ocr.py b/metagpt/prompts/invoice_ocr.py new file mode 100644 index 000000000..52f628a5b --- /dev/null +++ b/metagpt/prompts/invoice_ocr.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 16:30:25 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Prompts of the invoice ocr assistant. +""" + +COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice." + +EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """ +Please extract the payee, city, total cost, and invoicing date of the invoice. + +The OCR data of the invoice are as follows: +{ocr_result} + +Mandatory restrictions are returned according to the following requirements: +1. The total cost refers to the total price and tax. Do not include `¥`. +2. The city must be the recipient's city. +2. The returned JSON dictionary must be returned in {language} +3. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}. +""" + +REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """ +Please answer the question: {query} + +The OCR data of the invoice are as follows: +{ocr_result} + +Mandatory restrictions are returned according to the following requirements: +1. Answer in {language} language. +2. Enforce restrictions on not returning OCR data sent to you. +3. Return with markdown syntax layout. +""" + +INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice." + diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py new file mode 100644 index 000000000..ede41dfc2 --- /dev/null +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 14:10:05 +@Author : Stitch-z +@File : invoice_ocr_assistant.py +""" + +import os +from pathlib import Path + +import pandas as pd +from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion +from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS +from metagpt.roles import Role +from metagpt.schema import Message + + +class InvoiceOCRAssistant(Role): + """Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files, + generate a table for the payee, city, total amount, and invoicing date of the invoice, + and ask questions for a single file based on the OCR recognition results of the invoice. + + Args: + name: The name of the role. + profile: The role profile description. + goal: The goal of the role. + constraints: Constraints or requirements for the role. + language: The language in which the invoice table will be generated. + """ + + def __init__( + self, + name: str = "Stitch", + profile: str = "Invoice Ocr Assistant", + goal: str = "OCR identifies invoice files and generates invoice main information table", + constraints: str = "", + language: str = "ch", + ): + super().__init__(name, profile, goal, constraints) + self._init_actions([InvoiceOCR]) + self.language = language + self.filename = "" + self.origin_query = "" + self.orc_data = None + + async def _think(self) -> None: + """Determine the next action to be taken by the role.""" + if self._rc.todo is None: + self._set_state(0) + return + + if self._rc.state + 1 < len(self._states): + self._set_state(self._rc.state + 1) + else: + self._rc.todo = None + + async def _act(self) -> Message: + """Perform an action as determined by the role. + + Returns: + A message containing the result of the action. + """ + msg = self._rc.memory.get(k=1)[0] + todo = self._rc.todo + if isinstance(todo, InvoiceOCR): + self.origin_query = msg.content + file_path = msg.instruct_content.get("file_path") + self.filename = os.path.basename(file_path) + if not file_path: + raise Exception("Invoice file not uploaded") + + resp = await todo.run(Path(file_path), self.filename) + if len(resp) == 1: + # Single file support for questioning based on OCR recognition results + self._init_actions([GenerateTable, ReplyQuestion]) + self.orc_data = resp[0] + else: + self._init_actions([GenerateTable]) + + self._rc.todo = None + content = INVOICE_OCR_SUCCESS + elif isinstance(todo, GenerateTable): + ocr_results = msg.instruct_content + resp = await todo.run(ocr_results, self.filename) + + # Convert list to Markdown format string + df = pd.DataFrame(resp) + markdown_table = df.to_markdown(index=False) + content = f"{markdown_table}\n\n\n" + else: + resp = await todo.run(self.origin_query, self.orc_data) + content = resp + + msg = Message(content=content, instruct_content=resp) + self._rc.memory.add(msg) + return msg + + async def _react(self) -> Message: + """Execute the invoice ocr assistant's think and actions. + + Returns: + A message containing the final result of the assistant's actions. + """ + while True: + await self._think() + if self._rc.todo is None: + break + msg = await self._act() + return msg + diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 65cc15e82..aef5546fa 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -195,7 +195,8 @@ class OutputParser: except (ValueError, SyntaxError) as e: raise Exception(f"Error while extracting and parsing the {data_type}: {e}") else: - raise Exception(f"No {data_type} found in the text.") + logger.error(f"No {data_type} found in the text.") + return [] if data_type is list else {} class CodeParser: diff --git a/requirements.txt b/requirements.txt index de861ded9..51dc5c8aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +paddlepaddle==2.4.2 +paddleocr>=2.0.1 aiohttp==3.8.4 #azure_storage==0.37.0 channels==4.0.0 @@ -42,4 +44,4 @@ pytest-mock==3.11.1 open-interpreter==0.1.4; python_version>"3.9" ta==0.10.2 semantic-kernel==0.3.10.dev0 - +tabulate==0.9.0 diff --git a/tests/data/invoices/invoice-1.pdf b/tests/data/invoices/invoice-1.pdf new file mode 100644 index 000000000..7f53133ef Binary files /dev/null and b/tests/data/invoices/invoice-1.pdf differ diff --git a/tests/data/invoices/invoice-2.png b/tests/data/invoices/invoice-2.png new file mode 100644 index 000000000..412ec3a24 Binary files /dev/null and b/tests/data/invoices/invoice-2.png differ diff --git a/tests/data/invoices/invoice-3.jpg b/tests/data/invoices/invoice-3.jpg new file mode 100644 index 000000000..8fb74dd30 Binary files /dev/null and b/tests/data/invoices/invoice-3.jpg differ diff --git a/tests/data/invoices/invoice-4.zip b/tests/data/invoices/invoice-4.zip new file mode 100644 index 000000000..c6be51a2b Binary files /dev/null and b/tests/data/invoices/invoice-4.zip differ diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py new file mode 100644 index 000000000..0433a2a89 --- /dev/null +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/10/09 18:40:34 +@Author : Stitch-z +@File : test_invoice_ocr.py +""" + +import os +from pathlib import Path +from typing import List + +import pytest + +from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "invoice_path", + [ + "../../data/invoices/invoice-3.jpg", + "../../data/invoices/invoice-4.zip", + ] +) +async def test_invoice_ocr(invoice_path: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + assert isinstance(resp, list) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "expected_result"), + [ + ( + "../../data/invoices/invoice-1.pdf", + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": "412.00", + "开票日期": "2023年02月03日" + } + ] + ), + ] +) +async def test_generate_table(invoice_path: str, expected_result: List[dict]): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) + assert table_data == expected_result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "query", "expected_result"), + [ + ("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日") + ] +) +async def test_reply_question(invoice_path: str, query: dict, expected_result: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) + assert expected_result in result + diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py new file mode 100644 index 000000000..d73de3492 --- /dev/null +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 23:11:27 +@Author : Stitch-z +@File : test_invoice_ocr_assistant.py +""" + +import os +import pandas as pd +from typing import List + +import pytest + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant +from metagpt.schema import Message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "invoice_path", "invoice_table_path", "expected_result"), + [ + ( + "Invoicing date", + "../../data/invoices/invoice-1.pdf", + "../../../data/invoice_table/invoice-1.xlsx", + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": 412.00, + "开票日期": "2023年02月03日" + } + ] + ), + ( + "Invoicing date", + "../../data/invoices/invoice-2.png", + "../../../data/invoice_table/invoice-2.xlsx", + [ + { + "收款人": "铁头", + "城市": "广州市", + "总费用/元": 898.00, + "开票日期": "2023年03月17日" + } + ] + ), + ( + "Invoicing date", + "../../data/invoices/invoice-3.jpg", + "../../../data/invoice_table/invoice-3.xlsx", + [ + { + "收款人": "夏天", + "城市": "福州市", + "总费用/元": 2462.00, + "开票日期": "2023年08月26日" + } + ] + ), + ( + "Invoicing date", + "../../data/invoices/invoice-4.zip", + "../../../data/invoice_table/invoice-4.xlsx", + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": 412.00, + "开票日期": "2023年02月03日" + }, + { + "收款人": "铁头", + "城市": "广州市", + "总费用/元": 898.00, + "开票日期": "2023年03月17日" + }, + { + "收款人": "夏天", + "城市": "福州市", + "总费用/元": 2462.00, + "开票日期": "2023年08月26日" + } + ] + ), + ] +) +async def test_invoice_ocr_assistant( + query: str, + invoice_path: str, + invoice_table_path: str, + expected_result: List[dict] +): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + role = InvoiceOCRAssistant() + await role.run(Message( + content=query, + instruct_content={"file_path": invoice_path} + )) + invoice_table_path = os.path.abspath(os.path.join(os.getcwd(), invoice_table_path)) + df = pd.read_excel(invoice_table_path) + dict_result = df.to_dict(orient='records') + assert dict_result == expected_result diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index 2b706efc4..4e362f9f7 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -95,7 +95,7 @@ def test_parse_data(): """xxx xx""", list, None, - Exception, + [], ), ( """xxx [1, 2, []xx""",