Merge pull request #11 from Stitch-z/feature-invoice-ocr-assistant

Feature/invoice ocr assistant
This commit is contained in:
Stitch-z 2023-10-10 15:00:35 +08:00 committed by GitHub
commit d104777ed9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 599 additions and 3 deletions

View file

@ -3,7 +3,7 @@ FROM nikolaik/python-nodejs:python3.9-nodejs20-slim
# Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size
RUN apt update &&\
apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\
apt install -y libgomp1 git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\
apt clean && rm -rf /var/lib/apt/lists/*
# Install Mermaid CLI globally

39
examples/invoice_ocr.py Normal file
View file

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 21:40:57
@Author : Stitch-z
@File : invoice_ocr.py
"""
import asyncio
import os
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
async def main():
relative_paths = [
"../tests/data/invoices/invoice-1.pdf",
"../tests/data/invoices/invoice-2.png",
"../tests/data/invoices/invoice-3.jpg",
"../tests/data/invoices/invoice-4.zip"
]
# Get the current working directory
current_directory = os.getcwd()
# The absolute path of the file
absolute_file_paths = [os.path.abspath(os.path.join(current_directory, path)) for path in relative_paths]
for path in absolute_file_paths:
role = InvoiceOCRAssistant()
await role.run(Message(
content="Invoicing date",
instruct_content={"file_path": path}
))
if __name__ == '__main__':
asyncio.run(main())

View file

@ -0,0 +1,223 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 18:10:20
@Author : Stitch-z
@File : invoice_ocr.py
@Describe : Actions of the invoice ocr assistant.
"""
import os
import zipfile
from enum import Enum
from pathlib import Path
import pandas as pd
from typing import Dict, List
from paddleocr import PaddleOCR
from metagpt.actions import Action
from metagpt.const import INVOICE_OCR_TABLE_PATH
from metagpt.logs import logger
from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT
from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
class FileExtensionType(Enum):
"""Enum representing file extensions and their associated types.
Each enum member consists of a tuple containing the file extension and its associated type.
"""
ZIP = (".zip", "zip")
PDF = (".pdf", "pdf")
PNG = (".png", "png")
JPG = (".jpg", "jpg")
@classmethod
def get_extension_list(cls) -> List[str]:
"""Get a list of file extensions.
Returns:
A list of file extensions as strings.
"""
return [ext.value[0] for ext in cls]
@classmethod
def get_type_list(cls) -> List[str]:
"""Get a list of file types.
Returns:
A list of file types as strings.
"""
return [ext.value[1] for ext in cls]
class InvoiceOCR(Action):
"""Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files.
Args:
name: The name of the action. Defaults to an empty string.
language: The language for OCR output. Defaults to "ch" (Chinese).
"""
def __init__(self, name: str = "", *args, **kwargs):
super().__init__(name, *args, **kwargs)
@staticmethod
async def _check_file_type(filename: str) -> str:
"""Check the file type of the given filename.
Args:
filename: The name of the file.
Returns:
The file type based on FileExtensionType enum.
Raises:
Exception: If the file format is not zip, pdf, png, or jpg.
"""
file_ext = None
for ext in FileExtensionType:
if filename.endswith(ext.value[0]):
file_ext = ext.value[1]
break
if not file_ext:
raise Exception("The invoice format is not zip, pdf, png, or jpg")
return file_ext
@staticmethod
async def _unzip(file_path: Path) -> Path:
"""Unzip a file and return the path to the unzipped directory.
Args:
file_path: The path to the zip file.
Returns:
The path to the unzipped directory.
"""
file_directory = file_path.parent
with zipfile.ZipFile(file_path, 'r') as zip_ref:
for zip_info in zip_ref.infolist():
# Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code
relative_name = zip_info.filename.encode('cp437').decode('gbk')
unzip_dir, name = relative_name.split("/")
if name:
full_filename = file_directory / relative_name
await File.write(full_filename.parent, name, zip_ref.read(zip_info.filename))
unzip_path = file_directory / unzip_dir
logger.info(f"unzip_path: {unzip_path}")
return unzip_path
async def run(self, file_path: Path, filename: str, *args, **kwargs) -> list:
"""Execute the action to identify invoice files through OCR.
Args:
file_path: The path to the input file.
filename: The name of the input file.
Returns:
A list of OCR results.
"""
file_ext = await self._check_file_type(filename)
if file_ext == FileExtensionType.ZIP.value[1]:
# OCR recognizes zip batch files
unzip_path = await self._unzip(file_path)
file_list = os.listdir(unzip_path)
ocr_list = []
for filename in file_list:
invoice_file_path = unzip_path / filename
# Identify files that match the type
if filename.split(".")[-1] in FileExtensionType.get_type_list():
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
ocr_list.append(ocr_result)
return ocr_list
else:
# OCR identifies single file
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(file_path), cls=True)
return [ocr_result]
class GenerateTable(Action):
"""Action class for generating tables from OCR results.
Args:
name: The name of the action. Defaults to an empty string.
language: The language used for the generated table. Defaults to "ch" (Chinese).
"""
def __init__(self, name: str = "", language: str = "ch", *args, **kwargs):
super().__init__(name, *args, **kwargs)
self.language = language
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> Dict[str, str]:
"""Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file.
Args:
ocr_results: A list of OCR results obtained from invoice processing.
filename: The name of the output Excel file.
Returns:
A dictionary containing the invoice information.
"""
table_data = []
pathname = INVOICE_OCR_TABLE_PATH
pathname.mkdir(parents=True, exist_ok=True)
for ocr_result in ocr_results:
# Extract invoice OCR main information
prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language)
ocr_info = await self._aask(prompt=prompt)
invoice_data = OutputParser.extract_struct(ocr_info, dict)
if invoice_data:
table_data.append(invoice_data)
# Generate Excel file
filename = f"{filename.split('.')[0]}.xlsx"
full_filename = f"{pathname}/{filename}"
df = pd.DataFrame(table_data)
df.to_excel(full_filename, index=False)
return table_data
class ReplyQuestion(Action):
"""Action class for generating replies to questions based on OCR results.
Args:
name: The name of the action. Defaults to an empty string.
language: The language used for generating the reply. Defaults to "ch" (Chinese).
"""
def __init__(self, name: str = "", language: str = "ch", *args, **kwargs):
super().__init__(name, *args, **kwargs)
self.language = language
async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
"""Reply to questions based on ocr results.
Args:
query: The question for which a reply is generated.
ocr_result: A list of OCR results.
Returns:
A reply result of string type.
"""
prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language)
resp = await self._aask(prompt=prompt)
return resp

View file

@ -36,6 +36,7 @@ YAPI_URL = "http://yapi.deepwisdomai.com/"
TMP = PROJECT_ROOT / "tmp"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"
SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills"

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 16:30:25
@Author : Stitch-z
@File : invoice_ocr.py
@Describe : Prompts of the invoice ocr assistant.
"""
COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice."
EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """
Please extract the payee, city, total cost, and invoicing date of the invoice.
The OCR data of the invoice are as follows:
{ocr_result}
Mandatory restrictions are returned according to the following requirements:
1. The total cost refers to the total price and tax. Do not include `¥`.
2. The city must be the recipient's city.
2. The returned JSON dictionary must be returned in {language}
3. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}.
"""
REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """
Please answer the question: {query}
The OCR data of the invoice are as follows:
{ocr_result}
Mandatory restrictions are returned according to the following requirements:
1. Answer in {language} language.
2. Enforce restrictions on not returning OCR data sent to you.
3. Return with markdown syntax layout.
"""
INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice."

View file

@ -0,0 +1,112 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 14:10:05
@Author : Stitch-z
@File : invoice_ocr_assistant.py
"""
import os
from pathlib import Path
import pandas as pd
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
from metagpt.roles import Role
from metagpt.schema import Message
class InvoiceOCRAssistant(Role):
"""Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files,
generate a table for the payee, city, total amount, and invoicing date of the invoice,
and ask questions for a single file based on the OCR recognition results of the invoice.
Args:
name: The name of the role.
profile: The role profile description.
goal: The goal of the role.
constraints: Constraints or requirements for the role.
language: The language in which the invoice table will be generated.
"""
def __init__(
self,
name: str = "Stitch",
profile: str = "Invoice Ocr Assistant",
goal: str = "OCR identifies invoice files and generates invoice main information table",
constraints: str = "",
language: str = "ch",
):
super().__init__(name, profile, goal, constraints)
self._init_actions([InvoiceOCR])
self.language = language
self.filename = ""
self.origin_query = ""
self.orc_data = None
async def _think(self) -> None:
"""Determine the next action to be taken by the role."""
if self._rc.todo is None:
self._set_state(0)
return
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
else:
self._rc.todo = None
async def _act(self) -> Message:
"""Perform an action as determined by the role.
Returns:
A message containing the result of the action.
"""
msg = self._rc.memory.get(k=1)[0]
todo = self._rc.todo
if isinstance(todo, InvoiceOCR):
self.origin_query = msg.content
file_path = msg.instruct_content.get("file_path")
self.filename = os.path.basename(file_path)
if not file_path:
raise Exception("Invoice file not uploaded")
resp = await todo.run(Path(file_path), self.filename)
if len(resp) == 1:
# Single file support for questioning based on OCR recognition results
self._init_actions([GenerateTable, ReplyQuestion])
self.orc_data = resp[0]
else:
self._init_actions([GenerateTable])
self._rc.todo = None
content = INVOICE_OCR_SUCCESS
elif isinstance(todo, GenerateTable):
ocr_results = msg.instruct_content
resp = await todo.run(ocr_results, self.filename)
# Convert list to Markdown format string
df = pd.DataFrame(resp)
markdown_table = df.to_markdown(index=False)
content = f"{markdown_table}\n\n\n"
else:
resp = await todo.run(self.origin_query, self.orc_data)
content = resp
msg = Message(content=content, instruct_content=resp)
self._rc.memory.add(msg)
return msg
async def _react(self) -> Message:
"""Execute the invoice ocr assistant's think and actions.
Returns:
A message containing the final result of the assistant's actions.
"""
while True:
await self._think()
if self._rc.todo is None:
break
msg = await self._act()
return msg

View file

@ -195,7 +195,8 @@ class OutputParser:
except (ValueError, SyntaxError) as e:
raise Exception(f"Error while extracting and parsing the {data_type}: {e}")
else:
raise Exception(f"No {data_type} found in the text.")
logger.error(f"No {data_type} found in the text.")
return [] if data_type is list else {}
class CodeParser:

View file

@ -1,3 +1,5 @@
paddlepaddle==2.4.2
paddleocr>=2.0.1
aiohttp==3.8.4
#azure_storage==0.37.0
channels==4.0.0
@ -42,4 +44,5 @@ pytest-mock==3.11.1
open-interpreter==0.1.4; python_version>"3.9"
ta==0.10.2
semantic-kernel==0.3.10.dev0
tabulate==0.9.0

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 464 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 466 KiB

Binary file not shown.

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/10/09 18:40:34
@Author : Stitch-z
@File : test_invoice_ocr.py
"""
import os
from pathlib import Path
from typing import List
import pytest
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
@pytest.mark.asyncio
@pytest.mark.parametrize(
"invoice_path",
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
]
)
async def test_invoice_ocr(invoice_path: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
assert isinstance(resp, list)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
(
"../../data/invoices/invoice-1.pdf",
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": "412.00",
"开票日期": "2023年02月03日"
}
]
),
]
)
async def test_generate_table(invoice_path: str, expected_result: List[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
assert table_data == expected_result
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "query", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
]
)
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
assert expected_result in result

View file

@ -0,0 +1,106 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 23:11:27
@Author : Stitch-z
@File : test_invoice_ocr_assistant.py
"""
import os
import pandas as pd
from typing import List
import pytest
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
@pytest.mark.asyncio
@pytest.mark.parametrize(
("query", "invoice_path", "invoice_table_path", "expected_result"),
[
(
"Invoicing date",
"../../data/invoices/invoice-1.pdf",
"../../../data/invoice_table/invoice-1.xlsx",
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": 412.00,
"开票日期": "2023年02月03日"
}
]
),
(
"Invoicing date",
"../../data/invoices/invoice-2.png",
"../../../data/invoice_table/invoice-2.xlsx",
[
{
"收款人": "铁头",
"城市": "广州市",
"总费用/元": 898.00,
"开票日期": "2023年03月17日"
}
]
),
(
"Invoicing date",
"../../data/invoices/invoice-3.jpg",
"../../../data/invoice_table/invoice-3.xlsx",
[
{
"收款人": "夏天",
"城市": "福州市",
"总费用/元": 2462.00,
"开票日期": "2023年08月26日"
}
]
),
(
"Invoicing date",
"../../data/invoices/invoice-4.zip",
"../../../data/invoice_table/invoice-4.xlsx",
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": 412.00,
"开票日期": "2023年02月03日"
},
{
"收款人": "铁头",
"城市": "广州市",
"总费用/元": 898.00,
"开票日期": "2023年03月17日"
},
{
"收款人": "夏天",
"城市": "福州市",
"总费用/元": 2462.00,
"开票日期": "2023年08月26日"
}
]
),
]
)
async def test_invoice_ocr_assistant(
query: str,
invoice_path: str,
invoice_table_path: str,
expected_result: List[dict]
):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
role = InvoiceOCRAssistant()
await role.run(Message(
content=query,
instruct_content={"file_path": invoice_path}
))
invoice_table_path = os.path.abspath(os.path.join(os.getcwd(), invoice_table_path))
df = pd.read_excel(invoice_table_path)
dict_result = df.to_dict(orient='records')
assert dict_result == expected_result

View file

@ -95,7 +95,7 @@ def test_parse_data():
"""xxx xx""",
list,
None,
Exception,
[],
),
(
"""xxx [1, 2, []xx""",