Merge pull request #16 from geekan/main

update
This commit is contained in:
Guess 2023-10-26 10:32:39 +08:00 committed by GitHub
commit bf49ca4b20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 715 additions and 37 deletions

View file

@ -3,7 +3,7 @@ FROM nikolaik/python-nodejs:python3.9-nodejs20-slim
# Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size
RUN apt update &&\
apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\
apt install -y libgomp1 git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\
apt clean && rm -rf /var/lib/apt/lists/*
# Install Mermaid CLI globally

View file

@ -33,11 +33,13 @@ # MetaGPT: マルチエージェントフレームワーク
<p align="center">ソフトウェア会社のマルチロール図式(順次導入)</p>
## MetaGPTの能力
## MetaGPT の能力
https://github.com/geekan/MetaGPT/assets/34952977/34345016-5d13-489d-b9f9-b82ace413419
## 例(GPT-4 で完全生成)
例えば、`python startup.py "Toutiao のような RecSys をデザインする"`と入力すると、多くの出力が得られます
@ -46,6 +48,9 @@ ## 例GPT-4 で完全生成)
解析と設計を含む 1 つの例を生成するのに約 **$0.2**(GPT-4 の API 使用料)、完全なプロジェクトでは約 **$2.0** かかります。
## インストール
### インストールビデオガイド
@ -55,7 +60,7 @@ ### インストールビデオガイド
### 伝統的なインストール
```bash
# ステップ 1: NPM がシステムにインストールされていることを確認してください。次に mermaid-js をインストールします。
# ステップ 1: NPM がシステムにインストールされていることを確認してください。次に mermaid-js をインストールします。(お使いのコンピューターに npm がない場合は、Node.js 公式サイトで Node.js https://nodejs.org/ をインストールしてください。)
npm --version
sudo npm install -g @mermaid-js/mermaid-cli
@ -79,7 +84,7 @@ # ステップ 3: リポジトリをローカルマシンにクローンし、
npm install @mermaid-js/mermaid-cli
```
- config.yml に mmdc のコンフィギュレーションを記述するのを忘れないこと
- config.yml に mmdc のコンフィグを記述するのを忘れないこと
```yml
PUPPETEER_CONFIG: "./config/puppeteer-config.json"
@ -88,6 +93,71 @@ # ステップ 3: リポジトリをローカルマシンにクローンし、
- もし `pip install -e.` がエラー `[Errno 13] Permission denied: '/usr/local/lib/python3.11/dist-packages/test-easy-install-13129.write-test'` で失敗したら、代わりに `pip install -e. --user` を実行してみてください
- Mermaid charts を SVG、PNG、PDF 形式に変換します。Node.js 版の Mermaid-CLI に加えて、Python 版の Playwright、pyppeteer、または mermaid.ink をこのタスクに使用できるようになりました。
- Playwright
- **Playwright のインストール**
```bash
pip install playwright
```
- **必要なブラウザのインストール**
PDF 変換をサポートするには、Chromium をインストールしてください。
```bash
playwright install --with-deps chromium
```
- **`config.yaml` を修正**
config.yaml から MERMAID_ENGINE のコメントを外し、`playwright` に変更する
```yaml
MERMAID_ENGINE: playwright
```
- pyppeteer
- **pyppeteer のインストール**
```bash
pip install pyppeteer
```
- **自分のブラウザを使用**
pyppeteer を使えばインストールされているブラウザを使うことができます、以下の環境を設定してください
```bash
export PUPPETEER_EXECUTABLE_PATH="/path/to/your/chromium"  # edge や chrome の実行ファイルのパスでも可
```
ブラウザのインストールにこのコマンドを使わないでください、これは古すぎます
```bash
pyppeteer-install
```
- **`config.yaml` を修正**
config.yaml から MERMAID_ENGINE のコメントを外し、`pyppeteer` に変更する
```yaml
MERMAID_ENGINE: pyppeteer
```
- mermaid.ink
- **`config.yaml` を修正**
config.yaml から MERMAID_ENGINE のコメントを外し、`ink` に変更する
```yaml
MERMAID_ENGINE: ink
```
注: この方法は pdf エクスポートに対応していません。
### Docker によるインストール
```bash

37
examples/invoice_ocr.py Normal file
View file

@ -0,0 +1,37 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 21:40:57
@Author : Stitch-z
@File : invoice_ocr.py
"""
import asyncio
from pathlib import Path
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
async def main():
    """Run the invoice OCR assistant once for each bundled sample invoice.

    Every sample path is relative and is joined onto the current working
    directory before being handed to the assistant.
    NOTE(review): the paths start with ``../tests`` - presumably the script is
    meant to be launched from a directory that is a sibling of ``tests``.
    """
    sample_invoices = (
        "../tests/data/invoices/invoice-1.pdf",
        "../tests/data/invoices/invoice-2.png",
        "../tests/data/invoices/invoice-3.jpg",
        "../tests/data/invoices/invoice-4.zip",
    )
    # Resolve each sample against the working directory up front.
    invoice_paths = [Path.cwd() / Path(sample) for sample in sample_invoices]

    for invoice_path in invoice_paths:
        # A fresh assistant is created for every file, as the file name and
        # OCR data are stored on the role instance.
        assistant = InvoiceOCRAssistant()
        request = Message(
            content="Invoicing date",
            instruct_content={"file_path": invoice_path},
        )
        await assistant.run(request)


if __name__ == '__main__':
    asyncio.run(main())

View file

@ -0,0 +1,186 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 18:10:20
@Author : Stitch-z
@File : invoice_ocr.py
@Describe : Actions of the invoice ocr assistant.
"""
import os
import zipfile
from pathlib import Path
from datetime import datetime
import pandas as pd
from paddleocr import PaddleOCR
from metagpt.actions import Action
from metagpt.const import INVOICE_OCR_TABLE_PATH
from metagpt.logs import logger
from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT
from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
class InvoiceOCR(Action):
    """Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files.

    Args:
        name: The name of the action. Defaults to an empty string.
    """

    # File types that can be handed to the OCR engine directly.
    _OCR_EXTENSIONS = (".pdf", ".png", ".jpg")
    # File types accepted by ``run``; a zip archive is unpacked first.
    _SUPPORTED_EXTENSIONS = (".zip",) + _OCR_EXTENSIONS

    def __init__(self, name: str = "", *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    @staticmethod
    async def _check_file_type(file_path: Path) -> str:
        """Check the file type of the given file.

        Args:
            file_path: The path of the file.

        Returns:
            The lower-cased file extension (including the leading dot).

        Raises:
            Exception: If the file format is not zip, pdf, png, or jpg.
        """
        # Compare case-insensitively so that e.g. "INVOICE.PDF" is accepted.
        ext = file_path.suffix.lower()
        if ext not in InvoiceOCR._SUPPORTED_EXTENSIONS:
            raise Exception("The invoice format is not zip, pdf, png, or jpg")
        return ext

    @staticmethod
    def _decode_member_name(zip_info: zipfile.ZipInfo) -> str:
        """Return the archive member name with GBK-encoded names repaired."""
        # Bit 11 (0x800) of the general purpose flags marks a UTF-8 name, which
        # zipfile has already decoded correctly; re-encoding it as CP437 would
        # raise UnicodeEncodeError for non-ASCII characters.
        if zip_info.flag_bits & 0x800:
            return zip_info.filename
        try:
            # Names without the UTF-8 flag are decoded by zipfile as CP437.
            # Round-trip through CP437 and decode as GBK to prevent garbled
            # Chinese file names.
            return zip_info.filename.encode("cp437").decode("gbk")
        except (UnicodeEncodeError, UnicodeDecodeError):
            # Not a GBK name after all: keep what zipfile produced.
            return zip_info.filename

    @staticmethod
    async def _unzip(file_path: Path) -> Path:
        """Unzip a file and return the path to the unzipped directory.

        Members without a file extension (which includes directory entries)
        are skipped, as are members whose name would escape the target
        directory.

        Args:
            file_path: The path to the zip file.

        Returns:
            The path to the unzipped directory.
        """
        file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S")
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            for zip_info in zip_ref.infolist():
                relative_name = Path(InvoiceOCR._decode_member_name(zip_info))
                if not relative_name.suffix:
                    continue
                # Guard against "zip slip": never write outside file_directory.
                if relative_name.is_absolute() or ".." in relative_name.parts:
                    logger.warning(f"skip unsafe zip member: {relative_name}")
                    continue
                full_filename = file_directory / relative_name
                await File.write(full_filename.parent, relative_name.name, zip_ref.read(zip_info))

        logger.info(f"unzip_path: {file_directory}")
        return file_directory

    @staticmethod
    async def _ocr(invoice_file_path: Path):
        """Run PaddleOCR on one file and return its raw result."""
        ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
        ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
        return ocr_result

    async def run(self, file_path: Path, *args, **kwargs) -> list:
        """Execute the action to identify invoice files through OCR.

        Args:
            file_path: The path to the input file.

        Returns:
            A list of OCR results: one entry for a single file, or one entry
            per recognizable file found inside a zip archive.

        Raises:
            Exception: If the file format is not zip, pdf, png, or jpg.
        """
        file_ext = await self._check_file_type(file_path)

        if file_ext != ".zip":
            # OCR identifies single file
            return [await self._ocr(file_path)]

        # OCR recognizes zip batch files
        unzip_path = await self._unzip(file_path)
        ocr_list = []
        for root, _, files in os.walk(unzip_path):
            for filename in files:
                # Only pdf/png/jpg can be recognized; a nested zip archive is
                # not an OCR input and is therefore ignored.
                if Path(filename).suffix.lower() not in self._OCR_EXTENSIONS:
                    continue
                invoice_file_path = Path(root) / Path(filename)
                ocr_list.append(await self._ocr(invoice_file_path))
        return ocr_list
class GenerateTable(Action):
    """Action class for generating tables from OCR results.

    Args:
        name: The name of the action. Defaults to an empty string.
        language: The language used for the generated table. Defaults to "ch" (Chinese).
    """

    def __init__(self, name: str = "", language: str = "ch", *args, **kwargs):
        super().__init__(name, *args, **kwargs)
        self.language = language

    async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> list[dict]:
        """Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file.

        The Excel file is written into ``INVOICE_OCR_TABLE_PATH`` (created on
        demand) and is named after ``filename`` with its last extension
        replaced by ``.xlsx``.

        Args:
            ocr_results: A list of OCR results obtained from invoice processing.
            filename: The name of the invoice file the table is generated for.

        Returns:
            A list with one dictionary of invoice information per OCR result
            for which the LLM answer could be parsed into a non-empty dict.
        """
        table_data = []
        pathname = INVOICE_OCR_TABLE_PATH
        pathname.mkdir(parents=True, exist_ok=True)

        for ocr_result in ocr_results:
            # Extract invoice OCR main information
            prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language)
            ocr_info = await self._aask(prompt=prompt)
            invoice_data = OutputParser.extract_struct(ocr_info, dict)
            if invoice_data:
                table_data.append(invoice_data)

        # Generate Excel file. Only the final extension is dropped, so a name
        # such as "invoice.2023.pdf" becomes "invoice.2023.xlsx" instead of
        # being truncated at the first dot.
        full_filename = pathname / f"{Path(filename).stem}.xlsx"
        pd.DataFrame(table_data).to_excel(full_filename, index=False)
        return table_data
class ReplyQuestion(Action):
    """Action that answers a user question from the OCR text of an invoice.

    Args:
        name: The name of the action. Defaults to an empty string.
        language: The language the reply should be written in. Defaults to "ch" (Chinese).
    """

    def __init__(self, name: str = "", language: str = "ch", *args, **kwargs):
        super().__init__(name, *args, **kwargs)
        self.language = language

    async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
        """Ask the LLM to answer ``query`` using the given OCR result.

        Args:
            query: The question to be answered.
            ocr_result: The OCR result the answer must be based on.

        Returns:
            The reply produced by the LLM, as a string.
        """
        # Fill the reply template with the question, the OCR data and the
        # configured answer language, then forward it to the LLM.
        filled_prompt = REPLY_OCR_QUESTION_PROMPT.format(
            query=query,
            ocr_result=ocr_result,
            language=self.language,
        )
        return await self._aask(prompt=filled_prompt)

View file

@ -36,6 +36,7 @@ YAPI_URL = "http://yapi.deepwisdomai.com/"
TMP = PROJECT_ROOT / "tmp"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"
SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills"

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 16:30:25
@Author : Stitch-z
@File : invoice_ocr.py
@Describe : Prompts of the invoice ocr assistant.
"""
# Shared preamble placed in front of every invoice OCR prompt.
COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice."

# Template asking the LLM to extract the main invoice fields as a JSON dict.
# Placeholders: {ocr_result} (raw OCR data) and {language} (output language).
# The doubled braces keep the JSON example literal when str.format is applied.
EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """
Please extract the payee, city, total cost, and invoicing date of the invoice.
The OCR data of the invoice are as follows:
{ocr_result}
Mandatory restrictions are returned according to the following requirements:
1. The total cost refers to the total price and tax. Do not include `¥`.
2. The city must be the recipient's city.
3. The returned JSON dictionary must be returned in {language}
4. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}.
"""

# Template asking the LLM to answer a free-form question about one invoice.
# Placeholders: {query}, {ocr_result} and {language}.
REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """
Please answer the question: {query}
The OCR data of the invoice are as follows:
{ocr_result}
Mandatory restrictions are returned according to the following requirements:
1. Answer in {language} language.
2. Enforce restrictions on not returning OCR data sent to you.
3. Return with markdown syntax layout.
"""

# Message content reported once the OCR step has finished.
INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice."

View file

@ -111,19 +111,19 @@ class CostManager(metaclass=Singleton):
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
def log_and_reraise(retry_state):

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 14:10:05
@Author : Stitch-z
@File : invoice_ocr_assistant.py
"""
import pandas as pd
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
from metagpt.roles import Role
from metagpt.schema import Message
class InvoiceOCRAssistant(Role):
    """Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files,
    generate a table for the payee, city, total amount, and invoicing date of the invoice,
    and ask questions for a single file based on the OCR recognition results of the invoice.

    Args:
        name: The name of the role.
        profile: The role profile description.
        goal: The goal of the role.
        constraints: Constraints or requirements for the role.
        language: The language in which the invoice table will be generated.
    """

    def __init__(
        self,
        name: str = "Stitch",
        profile: str = "Invoice OCR Assistant",
        goal: str = "OCR identifies invoice files and generates invoice main information table",
        constraints: str = "",
        language: str = "ch",
    ):
        super().__init__(name, profile, goal, constraints)
        self._init_actions([InvoiceOCR])
        self.language = language
        # Name of the invoice file being processed; set by the OCR step.
        self.filename = ""
        # Content of the message that triggered the OCR step.
        self.origin_query = ""
        # OCR result of a single-file invoice, kept for the question step.
        self.orc_data = None

    async def _think(self) -> None:
        """Determine the next action to be taken by the role.

        With no pending action the state is set to 0; otherwise the state is
        advanced by one until the last state is reached, after which the
        pending action is cleared.
        """
        if self._rc.todo is None:
            self._set_state(0)
            return

        if self._rc.state + 1 < len(self._states):
            self._set_state(self._rc.state + 1)
        else:
            self._rc.todo = None

    async def _act(self) -> Message:
        """Perform an action as determined by the role.

        Returns:
            A message containing the result of the action.

        Raises:
            Exception: If the latest message carries no invoice file path.
        """
        msg = self._rc.memory.get(k=1)[0]
        todo = self._rc.todo
        if isinstance(todo, InvoiceOCR):
            self.origin_query = msg.content
            # Validate the path before touching its attributes, so a missing
            # file is reported as such rather than as an AttributeError.
            instruct_content = msg.instruct_content or {}
            file_path = instruct_content.get("file_path")
            if not file_path:
                raise Exception("Invoice file not uploaded")
            self.filename = file_path.name

            resp = await todo.run(file_path)
            if len(resp) == 1:
                # Single file support for questioning based on OCR recognition results
                self._init_actions([GenerateTable, ReplyQuestion])
                self.orc_data = resp[0]
            else:
                self._init_actions([GenerateTable])

            # The action list was replaced above; clearing the pending action
            # makes the next _think() call set the state back to 0.
            self._rc.todo = None
            content = INVOICE_OCR_SUCCESS
        elif isinstance(todo, GenerateTable):
            # The previous message carries the OCR results as instruct_content.
            ocr_results = msg.instruct_content
            resp = await todo.run(ocr_results, self.filename)

            # Convert list to Markdown format string
            df = pd.DataFrame(resp)
            markdown_table = df.to_markdown(index=False)
            content = f"{markdown_table}\n\n\n"
        else:
            resp = await todo.run(self.origin_query, self.orc_data)
            content = resp

        msg = Message(content=content, instruct_content=resp)
        self._rc.memory.add(msg)
        return msg

    async def _react(self) -> Message:
        """Execute the invoice ocr assistant's think and actions.

        Returns:
            A message containing the final result of the assistant's actions.
        """
        while True:
            await self._think()
            if self._rc.todo is None:
                break
            msg = await self._act()
        return msg

View file

@ -1,11 +1,11 @@
import re
from typing import List, Callable
from typing import List, Callable, Dict
from pathlib import Path
import wrapt
import textwrap
import inspect
from interpreter.interpreter import Interpreter
from interpreter.core.core import Interpreter
from metagpt.logs import logger
from metagpt.config import CONFIG
@ -41,13 +41,13 @@ class OpenCodeInterpreter(object):
interpreter.auto_run = auto_run
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
interpreter.api_key = CONFIG.openai_api_key
interpreter.api_base = CONFIG.openai_api_base
# interpreter.api_base = CONFIG.openai_api_base
self.interpreter = interpreter
def chat(self, query: str, reset: bool = True):
if reset:
self.interpreter.reset()
return self.interpreter.chat(query, return_messages=True)
return self.interpreter.chat(query)
@staticmethod
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
@ -61,11 +61,30 @@ class OpenCodeInterpreter(object):
assert language == 'python', f"Expect python language for default function_format, but got {language}."
function_format = """def {function_name}():\n{code}"""
# Extract the code module in the open-interpreter respond message.
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
# The query_respond of open-interpreter before v0.1.4 is:
# [{'role': 'user', 'content': your query string},
# {'role': 'assistant', 'content': plan from llm, 'function_call': {
# "name": "run_code", "arguments": "{"language": "python", "code": code of first plan},
# "parsed_arguments": {"language": "python", "code": code of first plan}
# ...]
if "function_call" in query_respond[1]:
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
# The query_respond of open-interpreter v0.1.7 is:
# [{'role': 'user', 'message': your query string},
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
# 'code': code of first plan, 'output': output of first plan code},
# ...]
elif "code" in query_respond[1]:
code = [item['code'] for item in query_respond
if "code" in item
and 'language' in item
and item['language'] == language]
else:
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
# add indent.
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
# Return the code after deduplication.
@ -94,13 +113,49 @@ class OpenInterpreterDecorator(object):
self.code_file_path = code_file_path
self.clear_code = clear_code
def _have_code(self, rsp: List[Dict]):
# Is there any code generated?
return 'code' in rsp[1] and rsp[1]['code'] not in ("", None)
def _is_faild_plan(self, rsp: List[Dict]):
    """Tell whether the plan in ``rsp`` counts as failed.

    The code is assembled with ``OpenCodeInterpreter.extract_function``; the
    plan is treated as failed when that result is a string containing at most
    one newline.
    """
    assembled = OpenCodeInterpreter.extract_function(rsp, 'function')
    # At most one '\n' means the plan execution failed.
    return isinstance(assembled, str) and assembled.count('\n') <= 1
def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3):
    """Validate an interpreter response, re-sending the query when needed.

    The query is sent again, at most ``max_try`` times, while the response
    either contains no code or contains a failed plan.

    Args:
        query: The query originally sent to the interpreter.
        interpreter: The interpreter used to re-send the query.
        respond: The response to validate.
        max_try: Maximum number of times the query is re-sent.

    Returns:
        A response that contains code and a plan that is not failed.

    Raises:
        ValueError: If the final response still has no code or a failed plan.
    """
    for _ in range(max_try):
        if not self._have_code(respond):
            logger.warning(f"llm did not return executable code, resend the query: \n{query}")
            respond = interpreter.chat(query)
        elif self._is_faild_plan(respond):
            logger.warning(f"llm did not generate successful plan, resend the query: \n{query}")
            respond = interpreter.chat(query)
        else:
            break

    # Post-processing of respond: report the two failure modes separately so
    # the log tells which one actually happened.
    if not self._have_code(respond):
        error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
        logger.error(error_msg)
        raise ValueError(error_msg)
    if self._is_faild_plan(respond):
        error_msg = f"OpenCodeInterpreter do not generate successful plan for query: \n{query}"
        logger.error(error_msg)
        raise ValueError(error_msg)
    return respond
def __call__(self, wrapped):
@wrapt.decorator
async def wrapper(wrapped: Callable, instance, args, kwargs):
# Get the decorated function name.
func_name = wrapped.__name__
# If the script exists locally and clearcode is not required, execute the function from the script.
if Path(self.code_file_path).is_file() and not self.clear_code:
if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code:
return run_function_script(self.code_file_path, func_name, *args, **kwargs)
# Auto run generate code by using open-interpreter.
@ -108,6 +163,8 @@ class OpenInterpreterDecorator(object):
query = gen_query(wrapped, args, kwargs)
logger.info(f"query for OpenCodeInterpreter: \n {query}")
respond = interpreter.chat(query)
# Make sure the response is as expected.
respond = self._check_respond(query, interpreter, respond, 3)
# Assemble the code blocks generated by open-interpreter into a function without parameters.
func_code = interpreter.extract_function(respond, func_name)
# Clone the `func_code` into wrapped, that is,
@ -121,9 +178,10 @@ class OpenInterpreterDecorator(object):
# execute this function.
try:
res = run_function_code(code, func_name, *args, **kwargs)
if self.save_code:
if self.save_code and self.code_file_path:
cf._save(self.code_file_path, code)
except Exception as e:
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
raise Exception("Could not evaluate Python code", e)
return res
return wrapper(wrapped)

View file

@ -195,7 +195,8 @@ class OutputParser:
except (ValueError, SyntaxError) as e:
raise Exception(f"Error while extracting and parsing the {data_type}: {e}")
else:
raise Exception(f"No {data_type} found in the text.")
logger.error(f"No {data_type} found in the text.")
return [] if data_type is list else {}
class CodeParser:

4
requirements-ocr.txt Normal file
View file

@ -0,0 +1,4 @@
paddlepaddle==2.4.2
paddleocr>=2.0.1
tabulate==0.9.0
-r requirements.txt

View file

@ -14,7 +14,7 @@ langchain==0.0.231
loguru==0.6.0
meilisearch==0.21.0
numpy==1.24.3
openai==0.27.8
openai
openpyxl
beautifulsoup4==4.12.2
pandas==2.0.3
@ -23,7 +23,7 @@ pydantic==1.10.8
#pymilvus==2.2.8
pytest==7.2.2
python_docx==0.8.11
PyYAML==6.0
PyYAML==6.0.1
# sentence_transformers==2.2.2
setuptools==65.6.3
tenacity==8.2.2
@ -39,13 +39,8 @@ typing_extensions==4.5.0
libcst==1.0.1
qdrant-client==1.4.0
pytest-mock==3.11.1
open-interpreter==0.1.4; python_version>"3.9"
open-interpreter==0.1.7; python_version>"3.9"
ta==0.10.2
semantic-kernel==0.3.10.dev0
semantic-kernel==0.3.13.dev0
wrapt==1.15.0
websocket-client==0.58.0
aiofiles~=23.2.1
pygments~=2.16.1
requests~=2.31.0
yaml~=0.2.5

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 464 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 466 KiB

Binary file not shown.

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/10/09 18:40:34
@Author : Stitch-z
@File : test_invoice_ocr.py
"""
import os
from typing import List
import pytest
from pathlib import Path
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "invoice_path",
    [
        "../../data/invoices/invoice-3.jpg",
        "../../data/invoices/invoice-4.zip",
    ]
)
async def test_invoice_ocr(invoice_path: str):
    """InvoiceOCR.run returns a list for a single image and for a zip archive."""
    # abspath resolves a relative path against the current working directory.
    resolved = os.path.abspath(invoice_path)
    ocr_action = InvoiceOCR()
    resp = await ocr_action.run(file_path=Path(resolved), filename=os.path.basename(resolved))
    assert isinstance(resp, list)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("invoice_path", "expected_result"),
    [
        (
            "../../data/invoices/invoice-1.pdf",
            [
                {
                    "收款人": "小明",
                    "城市": "深圳市",
                    "总费用/元": "412.00",
                    "开票日期": "2023年02月03日"
                }
            ]
        ),
    ]
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
    """GenerateTable.run extracts the expected rows from the OCR of an invoice."""
    # abspath resolves a relative path against the current working directory.
    resolved = os.path.abspath(invoice_path)
    base_name = os.path.basename(resolved)
    recognized = await InvoiceOCR().run(file_path=Path(resolved), filename=base_name)
    rows = await GenerateTable().run(ocr_results=recognized, filename=base_name)
    assert rows == expected_result
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("invoice_path", "query", "expected_result"),
    [
        ("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
    ]
)
async def test_reply_question(invoice_path: str, query: str, expected_result: str):
    """ReplyQuestion.run answers a question about an invoice from its OCR result.

    ``query`` is a plain question string, matching both the parametrized value
    and the ``query: str`` parameter of ``ReplyQuestion.run``.
    """
    invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
    filename = os.path.basename(invoice_path)
    ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
    result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
    # The reply is free-form text, so only check that it contains the date.
    assert expected_result in result

View file

@ -0,0 +1,105 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 23:11:27
@Author : Stitch-z
@File : test_invoice_ocr_assistant.py
"""
from pathlib import Path
import pytest
import pandas as pd
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
# Expected table rows for the three sample invoices; the zip sample is
# expected to yield all three rows in this order.
_INVOICE_1_ROW = {
    "收款人": "小明",
    "城市": "深圳市",
    "总费用/元": 412.00,
    "开票日期": "2023年02月03日"
}
_INVOICE_2_ROW = {
    "收款人": "铁头",
    "城市": "广州市",
    "总费用/元": 898.00,
    "开票日期": "2023年03月17日"
}
_INVOICE_3_ROW = {
    "收款人": "夏天",
    "城市": "福州市",
    "总费用/元": 2462.00,
    "开票日期": "2023年08月26日"
}


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("query", "invoice_path", "invoice_table_path", "expected_result"),
    [
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-1.pdf"),
            Path("../../../data/invoice_table/invoice-1.xlsx"),
            [_INVOICE_1_ROW],
        ),
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-2.png"),
            Path("../../../data/invoice_table/invoice-2.xlsx"),
            [_INVOICE_2_ROW],
        ),
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-3.jpg"),
            Path("../../../data/invoice_table/invoice-3.xlsx"),
            [_INVOICE_3_ROW],
        ),
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-4.zip"),
            Path("../../../data/invoice_table/invoice-4.xlsx"),
            [_INVOICE_1_ROW, _INVOICE_2_ROW, _INVOICE_3_ROW],
        ),
    ]
)
async def test_invoice_ocr_assistant(
    query: str,
    invoice_path: Path,
    invoice_table_path: Path,
    expected_result: list[dict]
):
    """Run the assistant on a sample invoice and compare the written Excel table."""
    # Both paths are relative and resolved against the working directory.
    source_file = Path.cwd() / invoice_path
    assistant = InvoiceOCRAssistant()
    request = Message(
        content=query,
        instruct_content={"file_path": source_file}
    )
    await assistant.run(request)

    # Read back the table the assistant saved and compare it row by row.
    table_file = Path.cwd() / invoice_table_path
    saved_rows = pd.read_excel(table_file).to_dict(orient='records')
    assert saved_rows == expected_result

View file

@ -95,7 +95,7 @@ def test_parse_data():
"""xxx xx""",
list,
None,
Exception,
[],
),
(
"""xxx [1, 2, []xx""",