Replace dict-typed `instruct_content` payloads in InvoiceOCRAssistant with pydantic models.

- examples/ and tests/: pass InvoicePath(file_path=...) instead of a raw dict.
- invoice_ocr_assistant.py: add InvoicePath / OCRResults / InvoiceData / ReplyData and wrap
  each action result in the matching model before it is stored in the Message.
- role.py: _act_by_order assigns a default `rsp` before its loop (see the inline comment there).

Review fixes folded into this revision:
- InvoicePath.file_path defaulted to "" (a str) under a `Path` annotation; it now defaults to
  None so the existing `if not file_path` check can fire.
- `self.filename = file_path.name` ran before that check; it is now guarded so a missing path
  reaches the "Invoice file not uploaded" error instead of an AttributeError.

NOTE(review): this patch was recovered from a copy whose newlines and indentation had been
collapsed. Indentation, blank context lines and intra-line spacing were reconstructed from the
hunk line counts; apply with `git apply --ignore-whitespace` if a context line fails to match.
The `index` line for invoice_ocr_assistant.py predates the two review fixes.

diff --git a/examples/invoice_ocr.py b/examples/invoice_ocr.py
index a6e565772..d9a2e8a6d 100644
--- a/examples/invoice_ocr.py
+++ b/examples/invoice_ocr.py
@@ -10,7 +10,7 @@
 import asyncio
 from pathlib import Path
 
-from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
+from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
 from metagpt.schema import Message
 
 
@@ -26,7 +26,7 @@ async def main():
 
     for path in absolute_file_paths:
         role = InvoiceOCRAssistant()
-        await role.run(Message(content="Invoicing date", instruct_content={"file_path": path}))
+        await role.run(Message(content="Invoicing date", instruct_content=InvoicePath(file_path=path)))
 
 
 if __name__ == "__main__":
diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py
index 1e28bc078..56d729fa9 100644
--- a/metagpt/roles/invoice_ocr_assistant.py
+++ b/metagpt/roles/invoice_ocr_assistant.py
@@ -7,9 +7,11 @@
 @File    : invoice_ocr_assistant.py
 """
 
+from pathlib import Path
 from typing import Optional
 
 import pandas as pd
+from pydantic import BaseModel
 
 from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
 from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
@@ -17,6 +19,22 @@ from metagpt.roles.role import Role, RoleReactMode
 from metagpt.schema import Message
 
 
+class InvoicePath(BaseModel):
+    file_path: Optional[Path] = None
+
+
+class OCRResults(BaseModel):
+    ocr_results: list[dict] = []
+
+
+class InvoiceData(BaseModel):
+    invoice_data: list[dict] = []
+
+
+class ReplyData(BaseModel):
+    content: str = ""
+
+
 class InvoiceOCRAssistant(Role):
     """Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files,
     generate a table for the payee, city, total amount, and invoicing date of the invoice,
@@ -54,7 +72,8 @@ class InvoiceOCRAssistant(Role):
         todo = self._rc.todo
         if isinstance(todo, InvoiceOCR):
             self.origin_query = msg.content
-            file_path = msg.instruct_content.get("file_path")
-            self.filename = file_path.name
+            invoice_path: InvoicePath = msg.instruct_content
+            file_path = invoice_path.file_path
+            self.filename = file_path.name if file_path else ""
             if not file_path:
                 raise Exception("Invoice file not uploaded")
@@ -69,17 +88,20 @@ class InvoiceOCRAssistant(Role):
 
             self._rc.todo = None
             content = INVOICE_OCR_SUCCESS
+            resp = OCRResults(ocr_results=resp)
         elif isinstance(todo, GenerateTable):
-            ocr_results = msg.instruct_content
-            resp = await todo.run(ocr_results, self.filename)
+            ocr_results: OCRResults = msg.instruct_content
+            resp = await todo.run(ocr_results.ocr_results, self.filename)
 
             # Convert list to Markdown format string
             df = pd.DataFrame(resp)
             markdown_table = df.to_markdown(index=False)
             content = f"{markdown_table}\n\n\n"
+            resp = InvoiceData(invoice_data=resp)
         else:
             resp = await todo.run(self.origin_query, self.orc_data)
             content = resp
+            resp = ReplyData(content=resp)
 
         msg = Message(content=content, instruct_content=resp)
         self._rc.memory.add(msg)
diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py
index 528e7d72d..8b048a523 100644
--- a/metagpt/roles/role.py
+++ b/metagpt/roles/role.py
@@ -489,6 +489,7 @@ class Role(BaseModel):
     async def _act_by_order(self) -> Message:
         """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
         start_idx = self._rc.state if self._rc.state >= 0 else 0  # action to run from recovered state
+        rsp = Message(content="No actions taken yet")  # return default message if _actions=[]
         for i in range(start_idx, len(self._states)):
             self._set_state(i)
             rsp = await self._act()
diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py
index c9aad93a7..e5a570f53 100644
--- a/tests/metagpt/roles/test_invoice_ocr_assistant.py
+++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py
@@ -12,7 +12,7 @@ from pathlib import Path
 import pandas as pd
 import pytest
 
-from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
+from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
 from metagpt.schema import Message
 
 
@@ -55,7 +55,7 @@ async def test_invoice_ocr_assistant(
 ):
     invoice_path = Path.cwd() / invoice_path
     role = InvoiceOCRAssistant()
-    await role.run(Message(content=query, instruct_content={"file_path": invoice_path}))
+    await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
     invoice_table_path = Path.cwd() / invoice_table_path
     df = pd.read_excel(invoice_table_path)
     dict_result = df.to_dict(orient="records")