diff --git a/Dockerfile b/Dockerfile index 8ab180e28..c6e22989b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ FROM nikolaik/python-nodejs:python3.9-nodejs20-slim # Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size RUN apt update &&\ - apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ + apt install -y libgomp1 git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ apt clean && rm -rf /var/lib/apt/lists/* # Install Mermaid CLI globally diff --git a/docs/README_JA.md b/docs/README_JA.md index 6ffc80ac7..6a7d43fc2 100644 --- a/docs/README_JA.md +++ b/docs/README_JA.md @@ -33,11 +33,13 @@ # MetaGPT: マルチエージェントフレームワーク
ソフトウェア会社のマルチロール図式(順次導入)
-## MetaGPTの能力 +## MetaGPT の能力 + https://github.com/geekan/MetaGPT/assets/34952977/34345016-5d13-489d-b9f9-b82ace413419 + ## 例(GPT-4 で完全生成) 例えば、`python startup.py "Toutiao のような RecSys をデザインする"`と入力すると、多くの出力が得られます @@ -46,6 +48,9 @@ ## 例(GPT-4 で完全生成) 解析と設計を含む 1 つの例を生成するのに約 **$0.2**(GPT-4 の API 使用料)、完全なプロジェクトでは約 **$2.0** かかります。 + + + ## インストール ### インストールビデオガイド @@ -55,7 +60,7 @@ ### インストールビデオガイド ### 伝統的なインストール ```bash -# ステップ 1: NPM がシステムにインストールされていることを確認してください。次に mermaid-js をインストールします。 +# ステップ 1: NPM がシステムにインストールされていることを確認してください。次に mermaid-js をインストールします。(お使いのコンピューターに npm がない場合は、Node.js 公式サイトで Node.js https://nodejs.org/ をインストールしてください。) npm --version sudo npm install -g @mermaid-js/mermaid-cli @@ -79,7 +84,7 @@ # ステップ 3: リポジトリをローカルマシンにクローンし、 npm install @mermaid-js/mermaid-cli ``` -- config.yml に mmdc のコンフィギュレーションを記述するのを忘れないこと +- config.yml に mmdc のコンフィグを記述するのを忘れないこと ```yml PUPPETEER_CONFIG: "./config/puppeteer-config.json" @@ -88,6 +93,71 @@ # ステップ 3: リポジトリをローカルマシンにクローンし、 - もし `pip install -e.` がエラー `[Errno 13] Permission denied: '/usr/local/lib/python3.11/dist-packages/test-easy-install-13129.write-test'` で失敗したら、代わりに `pip install -e. 
--user` を実行してみてください +- Mermaid charts を SVG、PNG、PDF 形式に変換します。Node.js 版の Mermaid-CLI に加えて、Python 版の Playwright、pyppeteer、または mermaid.ink をこのタスクに使用できるようになりました。 + + - Playwright + - **Playwright のインストール** + + ```bash + pip install playwright + ``` + + - **必要なブラウザのインストール** + + PDF 変換をサポートするには、Chromium をインストールしてください。 + + ```bash + playwright install --with-deps chromium + ``` + + - **`config.yaml` を修正** + + config.yaml から MERMAID_ENGINE のコメントを外し、`playwright` に変更する + + ```yaml + MERMAID_ENGINE: playwright + ``` + + - pyppeteer + - **pyppeteer のインストール** + + ```bash + pip install pyppeteer + ``` + + - **自分のブラウザを使用** + + pyppeteer を使えばインストールされているブラウザを使うことができます、以下の環境を設定してください + + ```bash + export PUPPETEER_EXECUTABLE_PATH=/path/to/your/chromium # or edge or chrome + ``` + + ブラウザのインストールにこのコマンドを使わないでください、これは古すぎます + + ```bash + pyppeteer-install + ``` + + - **`config.yaml` を修正** + + config.yaml から MERMAID_ENGINE のコメントを外し、`pyppeteer` に変更する + + ```yaml + MERMAID_ENGINE: pyppeteer + ``` + + - mermaid.ink + - **`config.yaml` を修正** + + config.yaml から MERMAID_ENGINE のコメントを外し、`ink` に変更する + + ```yaml + MERMAID_ENGINE: ink + ``` + + 注: この方法は PDF エクスポートに対応していません。 + ### Docker によるインストール ```bash diff --git a/examples/invoice_ocr.py b/examples/invoice_ocr.py new file mode 100644 index 000000000..11656ed52 --- /dev/null +++ b/examples/invoice_ocr.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 21:40:57 +@Author : Stitch-z +@File : invoice_ocr.py +""" + +import asyncio +from pathlib import Path + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant +from metagpt.schema import Message + + +async def main(): + relative_paths = [ + Path("../tests/data/invoices/invoice-1.pdf"), + Path("../tests/data/invoices/invoice-2.png"), + Path("../tests/data/invoices/invoice-3.jpg"), + Path("../tests/data/invoices/invoice-4.zip") + ] + # The absolute path of the file + absolute_file_paths = [Path.cwd() / path for path in relative_paths] + + for 
path in absolute_file_paths: + role = InvoiceOCRAssistant() + await role.run(Message( + content="Invoicing date", + instruct_content={"file_path": path} + )) + + +if __name__ == '__main__': + asyncio.run(main()) + diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py new file mode 100644 index 000000000..b37aa6885 --- /dev/null +++ b/metagpt/actions/invoice_ocr.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 18:10:20 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Actions of the invoice ocr assistant. +""" + +import os +import zipfile +from pathlib import Path +from datetime import datetime + +import pandas as pd +from paddleocr import PaddleOCR + +from metagpt.actions import Action +from metagpt.const import INVOICE_OCR_TABLE_PATH +from metagpt.logs import logger +from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT +from metagpt.utils.common import OutputParser +from metagpt.utils.file import File + + +class InvoiceOCR(Action): + """Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language for OCR output. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + + @staticmethod + async def _check_file_type(file_path: Path) -> str: + """Check the file type of the given filename. + + Args: + file_path: The path of the file. + + Returns: + The file type based on FileExtensionType enum. + + Raises: + Exception: If the file format is not zip, pdf, png, or jpg. 
+ """ + ext = file_path.suffix + if ext not in [".zip", ".pdf", ".png", ".jpg"]: + raise Exception("The invoice format is not zip, pdf, png, or jpg") + + return ext + + @staticmethod + async def _unzip(file_path: Path) -> Path: + """Unzip a file and return the path to the unzipped directory. + + Args: + file_path: The path to the zip file. + + Returns: + The path to the unzipped directory. + """ + file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S") + with zipfile.ZipFile(file_path, "r") as zip_ref: + for zip_info in zip_ref.infolist(): + # Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code + relative_name = Path(zip_info.filename.encode("cp437").decode("gbk")) + if relative_name.suffix: + full_filename = file_directory / relative_name + await File.write(full_filename.parent, relative_name.name, zip_ref.read(zip_info.filename)) + + logger.info(f"unzip_path: {file_directory}") + return file_directory + + @staticmethod + async def _ocr(invoice_file_path: Path): + ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) + ocr_result = ocr.ocr(str(invoice_file_path), cls=True) + return ocr_result + + async def run(self, file_path: Path, *args, **kwargs) -> list: + """Execute the action to identify invoice files through OCR. + + Args: + file_path: The path to the input file. + + Returns: + A list of OCR results. 
+ """ + file_ext = await self._check_file_type(file_path) + + if file_ext == ".zip": + # OCR recognizes zip batch files + unzip_path = await self._unzip(file_path) + ocr_list = [] + for root, _, files in os.walk(unzip_path): + for filename in files: + invoice_file_path = Path(root) / Path(filename) + # Identify files that match the type + if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]: + ocr_result = await self._ocr(str(invoice_file_path)) + ocr_list.append(ocr_result) + return ocr_list + + else: + # OCR identifies single file + ocr_result = await self._ocr(file_path) + return [ocr_result] + + +class GenerateTable(Action): + """Action class for generating tables from OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for the generated table. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", language: str = "ch", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.language = language + + async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]: + """Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file. + + Args: + ocr_results: A list of OCR results obtained from invoice processing. + filename: The name of the output Excel file. + + Returns: + A dictionary containing the invoice information. 
+ + """ + table_data = [] + pathname = INVOICE_OCR_TABLE_PATH + pathname.mkdir(parents=True, exist_ok=True) + + for ocr_result in ocr_results: + # Extract invoice OCR main information + prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language) + ocr_info = await self._aask(prompt=prompt) + invoice_data = OutputParser.extract_struct(ocr_info, dict) + if invoice_data: + table_data.append(invoice_data) + + # Generate Excel file + filename = f"{filename.split('.')[0]}.xlsx" + full_filename = f"{pathname}/{filename}" + df = pd.DataFrame(table_data) + df.to_excel(full_filename, index=False) + return table_data + + +class ReplyQuestion(Action): + """Action class for generating replies to questions based on OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for generating the reply. Defaults to "ch" (Chinese). + + """ + + def __init__(self, name: str = "", language: str = "ch", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.language = language + + async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str: + """Reply to questions based on ocr results. + + Args: + query: The question for which a reply is generated. + ocr_result: A list of OCR results. + + Returns: + A reply result of string type. 
+ """ + prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language) + resp = await self._aask(prompt=prompt) + return resp + diff --git a/metagpt/const.py b/metagpt/const.py index b8b08628e..7f3f87dfa 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -36,6 +36,7 @@ YAPI_URL = "http://yapi.deepwisdomai.com/" TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" +INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/prompts/invoice_ocr.py b/metagpt/prompts/invoice_ocr.py new file mode 100644 index 000000000..52f628a5b --- /dev/null +++ b/metagpt/prompts/invoice_ocr.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 16:30:25 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Prompts of the invoice ocr assistant. +""" + +COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice." + +EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """ +Please extract the payee, city, total cost, and invoicing date of the invoice. + +The OCR data of the invoice are as follows: +{ocr_result} + +Mandatory restrictions are returned according to the following requirements: +1. The total cost refers to the total price and tax. Do not include `¥`. +2. The city must be the recipient's city. +2. The returned JSON dictionary must be returned in {language} +3. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}. +""" + +REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """ +Please answer the question: {query} + +The OCR data of the invoice are as follows: +{ocr_result} + +Mandatory restrictions are returned according to the following requirements: +1. Answer in {language} language. +2. Enforce restrictions on not returning OCR data sent to you. +3. Return with markdown syntax layout. 
+""" + +INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice." + diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 7e865f288..6ebed2c16 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -111,19 +111,19 @@ class CostManager(metaclass=Singleton): return self.total_completion_tokens -def get_total_cost(self): - """ - Get the total cost of API calls. + def get_total_cost(self): + """ + Get the total cost of API calls. - Returns: - float: The total cost of API calls. - """ - return self.total_cost + Returns: + float: The total cost of API calls. + """ + return self.total_cost -def get_costs(self) -> Costs: - """Get all costs""" - return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) + def get_costs(self) -> Costs: + """Get all costs""" + return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) def log_and_reraise(retry_state): diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py new file mode 100644 index 000000000..c307b20c0 --- /dev/null +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 14:10:05 +@Author : Stitch-z +@File : invoice_ocr_assistant.py +""" + +import pandas as pd + +from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion +from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS +from metagpt.roles import Role +from metagpt.schema import Message + + +class InvoiceOCRAssistant(Role): + """Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files, + generate a table for the payee, city, total amount, and invoicing date of the invoice, + and ask questions for a single file based on the OCR recognition results of the invoice. + + Args: + name: The name of the role. 
+ profile: The role profile description. + goal: The goal of the role. + constraints: Constraints or requirements for the role. + language: The language in which the invoice table will be generated. + """ + + def __init__( + self, + name: str = "Stitch", + profile: str = "Invoice OCR Assistant", + goal: str = "OCR identifies invoice files and generates invoice main information table", + constraints: str = "", + language: str = "ch", + ): + super().__init__(name, profile, goal, constraints) + self._init_actions([InvoiceOCR]) + self.language = language + self.filename = "" + self.origin_query = "" + self.orc_data = None + + async def _think(self) -> None: + """Determine the next action to be taken by the role.""" + if self._rc.todo is None: + self._set_state(0) + return + + if self._rc.state + 1 < len(self._states): + self._set_state(self._rc.state + 1) + else: + self._rc.todo = None + + async def _act(self) -> Message: + """Perform an action as determined by the role. + + Returns: + A message containing the result of the action. 
+ """ + msg = self._rc.memory.get(k=1)[0] + todo = self._rc.todo + if isinstance(todo, InvoiceOCR): + self.origin_query = msg.content + file_path = msg.instruct_content.get("file_path") + self.filename = file_path.name + if not file_path: + raise Exception("Invoice file not uploaded") + + resp = await todo.run(file_path) + if len(resp) == 1: + # Single file support for questioning based on OCR recognition results + self._init_actions([GenerateTable, ReplyQuestion]) + self.orc_data = resp[0] + else: + self._init_actions([GenerateTable]) + + self._rc.todo = None + content = INVOICE_OCR_SUCCESS + elif isinstance(todo, GenerateTable): + ocr_results = msg.instruct_content + resp = await todo.run(ocr_results, self.filename) + + # Convert list to Markdown format string + df = pd.DataFrame(resp) + markdown_table = df.to_markdown(index=False) + content = f"{markdown_table}\n\n\n" + else: + resp = await todo.run(self.origin_query, self.orc_data) + content = resp + + msg = Message(content=content, instruct_content=resp) + self._rc.memory.add(msg) + return msg + + async def _react(self) -> Message: + """Execute the invoice ocr assistant's think and actions. + + Returns: + A message containing the final result of the assistant's actions. 
+ """ + while True: + await self._think() + if self._rc.todo is None: + break + msg = await self._act() + return msg + diff --git a/metagpt/tools/code_interpreter.py b/metagpt/tools/code_interpreter.py index 97398ccfd..e41eaab72 100644 --- a/metagpt/tools/code_interpreter.py +++ b/metagpt/tools/code_interpreter.py @@ -1,11 +1,11 @@ import re -from typing import List, Callable +from typing import List, Callable, Dict from pathlib import Path import wrapt import textwrap import inspect -from interpreter.interpreter import Interpreter +from interpreter.core.core import Interpreter from metagpt.logs import logger from metagpt.config import CONFIG @@ -41,13 +41,13 @@ class OpenCodeInterpreter(object): interpreter.auto_run = auto_run interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo" interpreter.api_key = CONFIG.openai_api_key - interpreter.api_base = CONFIG.openai_api_base + # interpreter.api_base = CONFIG.openai_api_base self.interpreter = interpreter def chat(self, query: str, reset: bool = True): if reset: self.interpreter.reset() - return self.interpreter.chat(query, return_messages=True) + return self.interpreter.chat(query) @staticmethod def extract_function(query_respond: List, function_name: str, *, language: str = 'python', @@ -61,11 +61,30 @@ class OpenCodeInterpreter(object): assert language == 'python', f"Expect python language for default function_format, but got {language}." function_format = """def {function_name}():\n{code}""" # Extract the code module in the open-interpreter respond message. 
- code = [item['function_call']['parsed_arguments']['code'] for item in query_respond - if "function_call" in item - and "parsed_arguments" in item["function_call"] - and 'language' in item["function_call"]['parsed_arguments'] - and item["function_call"]['parsed_arguments']['language'] == language] + # The query_respond of open-interpreter before v0.1.4 is: + # [{'role': 'user', 'content': your query string}, + # {'role': 'assistant', 'content': plan from llm, 'function_call': { + # "name": "run_code", "arguments": "{"language": "python", "code": code of first plan}, + # "parsed_arguments": {"language": "python", "code": code of first plan} + # ...] + if "function_call" in query_respond[1]: + code = [item['function_call']['parsed_arguments']['code'] for item in query_respond + if "function_call" in item + and "parsed_arguments" in item["function_call"] + and 'language' in item["function_call"]['parsed_arguments'] + and item["function_call"]['parsed_arguments']['language'] == language] + # The query_respond of open-interpreter v0.1.7 is: + # [{'role': 'user', 'message': your query string}, + # {'role': 'assistant', 'message': plan from llm, 'language': 'python', + # 'code': code of first plan, 'output': output of first plan code}, + # ...] + elif "code" in query_respond[1]: + code = [item['code'] for item in query_respond + if "code" in item + and 'language' in item + and item['language'] == language] + else: + raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}") # add indent. indented_code_str = textwrap.indent("\n".join(code), ' ' * 4) # Return the code after deduplication. @@ -94,13 +113,49 @@ class OpenInterpreterDecorator(object): self.code_file_path = code_file_path self.clear_code = clear_code + def _have_code(self, rsp: List[Dict]): + # Is there any code generated? + return 'code' in rsp[1] and rsp[1]['code'] not in ("", None) + + def _is_faild_plan(self, rsp: List[Dict]): + # is faild plan? 
+ func_code = OpenCodeInterpreter.extract_function(rsp, 'function') + # If there is no more than 1 '\n', the plan execution fails. + if isinstance(func_code, str) and func_code.count('\n') <= 1: + return True + return False + + def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3): + for _ in range(max_try): + # TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times. + if self._have_code(respond) and not self._is_faild_plan(respond): + break + elif not self._have_code(respond): + logger.warning(f"llm did not return executable code, resend the query: \n{query}") + respond = interpreter.chat(query) + elif self._is_faild_plan(respond): + logger.warning(f"llm did not generate successful plan, resend the query: \n{query}") + respond = interpreter.chat(query) + + # Post-processing of respond + if not self._have_code(respond): + error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" + logger.error(error_msg) + raise ValueError(error_msg) + + if self._is_faild_plan(respond): + error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" + logger.error(error_msg) + raise ValueError(error_msg) + return respond + def __call__(self, wrapped): @wrapt.decorator async def wrapper(wrapped: Callable, instance, args, kwargs): # Get the decorated function name. func_name = wrapped.__name__ # If the script exists locally and clearcode is not required, execute the function from the script. - if Path(self.code_file_path).is_file() and not self.clear_code: + if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code: return run_function_script(self.code_file_path, func_name, *args, **kwargs) # Auto run generate code by using open-interpreter. 
@@ -108,6 +163,8 @@ class OpenInterpreterDecorator(object): query = gen_query(wrapped, args, kwargs) logger.info(f"query for OpenCodeInterpreter: \n {query}") respond = interpreter.chat(query) + # Make sure the response is as expected. + respond = self._check_respond(query, interpreter, respond, 3) # Assemble the code blocks generated by open-interpreter into a function without parameters. func_code = interpreter.extract_function(respond, func_name) # Clone the `func_code` into wrapped, that is, @@ -121,9 +178,10 @@ class OpenInterpreterDecorator(object): # execute this function. try: res = run_function_code(code, func_name, *args, **kwargs) - if self.save_code: + if self.save_code and self.code_file_path: cf._save(self.code_file_path, code) except Exception as e: + logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}") raise Exception("Could not evaluate Python code", e) return res return wrapper(wrapped) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 59d179808..f09666beb 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -195,7 +195,8 @@ class OutputParser: except (ValueError, SyntaxError) as e: raise Exception(f"Error while extracting and parsing the {data_type}: {e}") else: - raise Exception(f"No {data_type} found in the text.") + logger.error(f"No {data_type} found in the text.") + return [] if data_type is list else {} class CodeParser: diff --git a/requirements-ocr.txt b/requirements-ocr.txt new file mode 100644 index 000000000..cf6103afc --- /dev/null +++ b/requirements-ocr.txt @@ -0,0 +1,4 @@ +paddlepaddle==2.4.2 +paddleocr>=2.0.1 +tabulate==0.9.0 +-r requirements.txt diff --git a/requirements.txt b/requirements.txt index 562a653f3..093298775 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai==0.27.8 +openai openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 @@ -23,7 +23,7 @@ pydantic==1.10.8 
#pymilvus==2.2.8 pytest==7.2.2 python_docx==0.8.11 -PyYAML==6.0 +PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 tenacity==8.2.2 @@ -39,13 +39,8 @@ typing_extensions==4.5.0 libcst==1.0.1 qdrant-client==1.4.0 pytest-mock==3.11.1 -open-interpreter==0.1.4; python_version>"3.9" +open-interpreter==0.1.7; python_version>"3.9" ta==0.10.2 -semantic-kernel==0.3.10.dev0 +semantic-kernel==0.3.13.dev0 +wrapt==1.15.0 websocket-client==0.58.0 - - -aiofiles~=23.2.1 -pygments~=2.16.1 -requests~=2.31.0 -yaml~=0.2.5 \ No newline at end of file diff --git a/tests/data/invoices/invoice-1.pdf b/tests/data/invoices/invoice-1.pdf new file mode 100644 index 000000000..7f53133ef Binary files /dev/null and b/tests/data/invoices/invoice-1.pdf differ diff --git a/tests/data/invoices/invoice-2.png b/tests/data/invoices/invoice-2.png new file mode 100644 index 000000000..412ec3a24 Binary files /dev/null and b/tests/data/invoices/invoice-2.png differ diff --git a/tests/data/invoices/invoice-3.jpg b/tests/data/invoices/invoice-3.jpg new file mode 100644 index 000000000..8fb74dd30 Binary files /dev/null and b/tests/data/invoices/invoice-3.jpg differ diff --git a/tests/data/invoices/invoice-4.zip b/tests/data/invoices/invoice-4.zip new file mode 100644 index 000000000..c6be51a2b Binary files /dev/null and b/tests/data/invoices/invoice-4.zip differ diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py new file mode 100644 index 000000000..a15166f7c --- /dev/null +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/10/09 18:40:34 +@Author : Stitch-z +@File : test_invoice_ocr.py +""" + +import os +from typing import List + +import pytest +from pathlib import Path + +from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "invoice_path", + [ + "../../data/invoices/invoice-3.jpg", + 
"../../data/invoices/invoice-4.zip", + ] +) +async def test_invoice_ocr(invoice_path: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + assert isinstance(resp, list) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "expected_result"), + [ + ( + "../../data/invoices/invoice-1.pdf", + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": "412.00", + "开票日期": "2023年02月03日" + } + ] + ), + ] +) +async def test_generate_table(invoice_path: str, expected_result: list[dict]): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) + assert table_data == expected_result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "query", "expected_result"), + [ + ("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日") + ] +) +async def test_reply_question(invoice_path: str, query: dict, expected_result: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) + assert expected_result in result + diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py new file mode 100644 index 000000000..75097e73c --- /dev/null +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 23:11:27 +@Author : Stitch-z +@File : test_invoice_ocr_assistant.py +""" + +from pathlib import Path + +import 
pytest +import pandas as pd + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant +from metagpt.schema import Message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "invoice_path", "invoice_table_path", "expected_result"), + [ + ( + "Invoicing date", + Path("../../data/invoices/invoice-1.pdf"), + Path("../../../data/invoice_table/invoice-1.xlsx"), + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": 412.00, + "开票日期": "2023年02月03日" + } + ] + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-2.png"), + Path("../../../data/invoice_table/invoice-2.xlsx"), + [ + { + "收款人": "铁头", + "城市": "广州市", + "总费用/元": 898.00, + "开票日期": "2023年03月17日" + } + ] + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-3.jpg"), + Path("../../../data/invoice_table/invoice-3.xlsx"), + [ + { + "收款人": "夏天", + "城市": "福州市", + "总费用/元": 2462.00, + "开票日期": "2023年08月26日" + } + ] + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-4.zip"), + Path("../../../data/invoice_table/invoice-4.xlsx"), + [ + { + "收款人": "小明", + "城市": "深圳市", + "总费用/元": 412.00, + "开票日期": "2023年02月03日" + }, + { + "收款人": "铁头", + "城市": "广州市", + "总费用/元": 898.00, + "开票日期": "2023年03月17日" + }, + { + "收款人": "夏天", + "城市": "福州市", + "总费用/元": 2462.00, + "开票日期": "2023年08月26日" + } + ] + ), + ] +) +async def test_invoice_ocr_assistant( + query: str, + invoice_path: Path, + invoice_table_path: Path, + expected_result: list[dict] +): + invoice_path = Path.cwd() / invoice_path + role = InvoiceOCRAssistant() + await role.run(Message( + content=query, + instruct_content={"file_path": invoice_path} + )) + invoice_table_path = Path.cwd() / invoice_table_path + df = pd.read_excel(invoice_table_path) + dict_result = df.to_dict(orient='records') + assert dict_result == expected_result + diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index 2b706efc4..4e362f9f7 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ 
b/tests/metagpt/utils/test_output_parser.py @@ -95,7 +95,7 @@ def test_parse_data(): """xxx xx""", list, None, - Exception, + [], ), ( """xxx [1, 2, []xx""",