Feature: invoice OCR assistant

This commit is contained in:
Stitch-z 2023-10-10 14:26:03 +08:00
parent 443c044990
commit e3ecb88293
15 changed files with 598 additions and 4 deletions

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/10/09 18:40:34
@Author : Stitch-z
@File : test_invoice_ocr.py
"""
import os
from pathlib import Path
from typing import List
import pytest
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "invoice_path",
    [
        "../../data/invoices/invoice-3.jpg",
        "../../data/invoices/invoice-4.zip",
    ],
)
async def test_invoice_ocr(invoice_path: str):
    """Run ``InvoiceOCR`` on each sample file and check that a list comes back.

    ``invoice_path`` is joined onto the current working directory, so the
    sample is only found when pytest is started from a directory in which
    that relative path resolves.
    """
    joined = os.path.join(os.getcwd(), invoice_path)
    absolute_path = Path(os.path.abspath(joined))
    ocr_output = await InvoiceOCR().run(
        file_path=absolute_path,
        filename=absolute_path.name,
    )
    assert isinstance(ocr_output, list)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("invoice_path", "expected_result"),
    [
        (
            "../../data/invoices/invoice-1.pdf",
            [
                {
                    "收款人": "小明",
                    "城市": "深圳市",
                    "总费用/元": "412.00",
                    "开票日期": "2023年02月03日",
                },
            ],
        ),
    ],
)
async def test_generate_table(invoice_path: str, expected_result: List[dict]):
    """Feed the OCR output of a sample invoice to ``GenerateTable``.

    The OCR result and the invoice's file name are passed to
    ``GenerateTable().run`` and the returned value must equal
    ``expected_result`` exactly.

    ``invoice_path`` is resolved against the current working directory, so
    pytest has to be launched from a directory in which it resolves.
    """
    absolute_path = Path(os.path.abspath(os.path.join(os.getcwd(), invoice_path)))
    invoice_name = absolute_path.name
    recognised = await InvoiceOCR().run(file_path=absolute_path, filename=invoice_name)
    rows = await GenerateTable().run(ocr_results=recognised, filename=invoice_name)
    assert rows == expected_result
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("invoice_path", "query", "expected_result"),
    [
        ("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日"),
    ],
)
async def test_reply_question(invoice_path: str, query: str, expected_result: str):
    """Ask ``ReplyQuestion`` about a sample invoice and check the answer.

    The invoice is first run through ``InvoiceOCR``; the OCR result and the
    question text are then handed to ``ReplyQuestion().run``. The test passes
    when ``expected_result`` occurs in the returned value.

    Args:
        invoice_path: Path of the sample invoice, relative to the current
            working directory (pytest must be started from a directory in
            which it resolves).
        query: The question text passed to ``ReplyQuestion``.
        expected_result: Text that must be contained in the reply.
    """
    invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
    filename = os.path.basename(invoice_path)
    ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
    result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
    # Containment rather than equality: the reply may carry additional text.
    assert expected_result in result

View file

@ -0,0 +1,105 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 23:11:27
@Author : Stitch-z
@File : test_invoice_ocr_assistant.py
"""
import os
import pandas as pd
from typing import List
import pytest
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("query", "invoice_path", "invoice_table_path", "expected_result"),
    [
        (
            "Invoicing date",
            "../../data/invoices/invoice-1.pdf",
            "../../../data/invoice_table/invoice-1.xlsx",
            [
                {
                    "收款人": "小明",
                    "城市": "深圳市",
                    "总费用/元": 412.00,
                    "开票日期": "2023年02月03日",
                },
            ],
        ),
        (
            "Invoicing date",
            "../../data/invoices/invoice-2.png",
            "../../../data/invoice_table/invoice-2.xlsx",
            [
                {
                    "收款人": "铁头",
                    "城市": "广州市",
                    "总费用/元": 898.00,
                    "开票日期": "2023年03月17日",
                },
            ],
        ),
        (
            "Invoicing date",
            "../../data/invoices/invoice-3.jpg",
            "../../../data/invoice_table/invoice-3.xlsx",
            [
                {
                    "收款人": "夏天",
                    "城市": "福州市",
                    "总费用/元": 2462.00,
                    "开票日期": "2023年08月26日",
                },
            ],
        ),
        (
            "Invoicing date",
            "../../data/invoices/invoice-4.zip",
            "../../../data/invoice_table/invoice-4.xlsx",
            [
                {
                    "收款人": "小明",
                    "城市": "深圳市",
                    "总费用/元": 412.00,
                    "开票日期": "2023年02月03日",
                },
                {
                    "收款人": "铁头",
                    "城市": "广州市",
                    "总费用/元": 898.00,
                    "开票日期": "2023年03月17日",
                },
                {
                    "收款人": "夏天",
                    "城市": "福州市",
                    "总费用/元": 2462.00,
                    "开票日期": "2023年08月26日",
                },
            ],
        ),
    ],
)
async def test_invoice_ocr_assistant(
    query: str,
    invoice_path: str,
    invoice_table_path: str,
    expected_result: List[dict],
):
    """Drive ``InvoiceOCRAssistant`` with a sample invoice and verify the table.

    A ``Message`` carrying the query text and the absolute invoice path is
    sent to the role. Afterwards the spreadsheet at ``invoice_table_path`` is
    loaded with pandas and its rows, as a list of dicts, must equal
    ``expected_result``.

    Both relative paths are resolved against the current working directory,
    so pytest has to be launched from a directory in which they resolve.
    NOTE(review): the spreadsheet is presumably written by the role during
    ``run`` -- confirm against the role implementation.
    """
    absolute_invoice = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
    request = Message(
        content=query,
        instruct_content={"file_path": absolute_invoice},
    )
    assistant = InvoiceOCRAssistant()
    await assistant.run(request)

    # Resolve the table path only after the role has finished, as before.
    absolute_table = os.path.abspath(os.path.join(os.getcwd(), invoice_table_path))
    records = pd.read_excel(absolute_table).to_dict(orient="records")
    assert records == expected_result

View file

@ -95,7 +95,7 @@ def test_parse_data():
"""xxx xx""",
list,
None,
Exception,
[],
),
(
"""xxx [1, 2, []xx""",