Merge branch 'main' into minor-edits

This commit is contained in:
Shashank Harinath 2023-10-31 21:00:34 -07:00 committed by GitHub
commit 60bb53574f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 722 additions and 44 deletions

View file

@ -0,0 +1,186 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 18:10:20
@Author : Stitch-z
@File : invoice_ocr.py
@Describe : Actions of the invoice ocr assistant.
"""
import os
import zipfile
from pathlib import Path
from datetime import datetime
import pandas as pd
from paddleocr import PaddleOCR
from metagpt.actions import Action
from metagpt.const import INVOICE_OCR_TABLE_PATH
from metagpt.logs import logger
from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT
from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
class InvoiceOCR(Action):
    """Action that performs OCR on an invoice file.

    A single pdf/png/jpg file is recognised directly; a zip archive is
    unpacked first and every pdf/png/jpg member is recognised in turn.

    Args:
        name: The name of the action. Defaults to an empty string.
    """

    # Suffixes that are handed to the OCR engine, and the full set `run` accepts.
    _OCR_SUFFIXES = (".pdf", ".png", ".jpg")
    _SUPPORTED_SUFFIXES = (".zip",) + _OCR_SUFFIXES

    def __init__(self, name: str = "", *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    @staticmethod
    async def _check_file_type(file_path: Path) -> str:
        """Validate the extension of the given file.

        Args:
            file_path: The path of the file.

        Returns:
            The lower-cased file suffix (one of ".zip", ".pdf", ".png", ".jpg").

        Raises:
            Exception: If the file format is not zip, pdf, png, or jpg.
        """
        # Compare case-insensitively so that e.g. "scan.PDF" is accepted too.
        ext = file_path.suffix.lower()
        if ext not in InvoiceOCR._SUPPORTED_SUFFIXES:
            raise Exception("The invoice format is not zip, pdf, png, or jpg")
        return ext

    @staticmethod
    def _decode_member_name(zip_info: zipfile.ZipInfo) -> str:
        """Return the member name of a zip entry with Chinese characters repaired."""
        # Bit 11 (0x800) of the general-purpose flags marks a UTF-8 encoded name,
        # which zipfile has already decoded correctly; re-decoding would corrupt it.
        if zip_info.flag_bits & 0x800:
            return zip_info.filename
        try:
            # Use CP437 to encode the file name, and then use GBK decoding to
            # prevent Chinese garbled code.
            return zip_info.filename.encode("cp437").decode("gbk")
        except (UnicodeEncodeError, UnicodeDecodeError):
            # Not a GBK name after all: keep what zipfile gave us.
            return zip_info.filename

    @staticmethod
    async def _unzip(file_path: Path) -> Path:
        """Unzip a file and return the path to the unzipped directory.

        Only members that have a file suffix are written. Members whose name
        would resolve outside the extraction directory are skipped.

        Args:
            file_path: The path to the zip file.

        Returns:
            The path to the unzipped directory.
        """
        file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S")
        extract_root = file_directory.resolve()
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            for zip_info in zip_ref.infolist():
                relative_name = Path(InvoiceOCR._decode_member_name(zip_info))
                if not relative_name.suffix:
                    continue
                full_filename = file_directory / relative_name
                # Guard against "zip slip": absolute names or ".." components
                # must not be allowed to write outside the extraction directory.
                try:
                    full_filename.resolve().relative_to(extract_root)
                except ValueError:
                    logger.warning(f"skip unsafe zip member: {zip_info.filename}")
                    continue
                await File.write(full_filename.parent, relative_name.name, zip_ref.read(zip_info))

        logger.info(f"unzip_path: {file_directory}")
        return file_directory

    @staticmethod
    async def _ocr(invoice_file_path: Path):
        """Run PaddleOCR on one file and return the raw recognition result."""
        # NOTE(review): page_num=1 presumably limits PDF input to the first page - confirm.
        ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
        ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
        return ocr_result

    async def run(self, file_path: Path, *args, **kwargs) -> list:
        """Execute the action to identify invoice files through OCR.

        Args:
            file_path: The path to the input file.

        Returns:
            A list of OCR results, one entry per recognised file.

        Raises:
            Exception: If the file format is not zip, pdf, png, or jpg.
        """
        file_ext = await self._check_file_type(file_path)
        if file_ext != ".zip":
            # OCR identifies single file
            return [await self._ocr(file_path)]

        # OCR recognizes zip batch files
        unzip_path = await self._unzip(file_path)
        ocr_list = []
        for root, _, files in os.walk(unzip_path):
            for filename in files:
                # Only real invoice documents go to the OCR engine; a nested
                # archive is not an image/PDF and is therefore ignored.
                if Path(filename).suffix.lower() in self._OCR_SUFFIXES:
                    ocr_list.append(await self._ocr(Path(root) / filename))
        return ocr_list
class GenerateTable(Action):
    """Action class for generating tables from OCR results.

    Args:
        name: The name of the action. Defaults to an empty string.
        language: The language used for the generated table. Defaults to "ch" (Chinese).
    """

    def __init__(self, name: str = "", language: str = "ch", *args, **kwargs):
        super().__init__(name, *args, **kwargs)
        self.language = language

    async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> list[dict]:
        """Extract invoice information from OCR results and save it as an Excel file.

        Each OCR result is sent to the LLM with the extraction prompt; every
        reply from which a dictionary can be parsed becomes one table row.

        Args:
            ocr_results: A list of OCR results obtained from invoice processing.
            filename: Name of the source file; its stem (the name without the
                final extension) is reused for the output ``.xlsx`` file.

        Returns:
            A list of dictionaries, one per invoice whose information was extracted.
        """
        table_data = []
        pathname = INVOICE_OCR_TABLE_PATH
        pathname.mkdir(parents=True, exist_ok=True)

        for ocr_result in ocr_results:
            # Extract invoice OCR main information
            prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language)
            ocr_info = await self._aask(prompt=prompt)
            invoice_data = OutputParser.extract_struct(ocr_info, dict)
            if invoice_data:
                table_data.append(invoice_data)

        # Generate Excel file. Path.stem drops only the last suffix, so a name
        # such as "invoice.2023.pdf" becomes "invoice.2023.xlsx" rather than
        # being cut at the first dot.
        full_filename = pathname / f"{Path(filename).stem}.xlsx"
        pd.DataFrame(table_data).to_excel(full_filename, index=False)
        return table_data
class ReplyQuestion(Action):
    """Action that answers a user question from the OCR output of an invoice.

    Args:
        name: The name of the action. Defaults to an empty string.
        language: The language the answer is requested in. Defaults to "ch" (Chinese).
    """

    def __init__(self, name: str = "", language: str = "ch", *args, **kwargs):
        super().__init__(name, *args, **kwargs)
        self.language = language

    async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
        """Ask the LLM to answer `query` using the given OCR result.

        Args:
            query: The question for which a reply is generated.
            ocr_result: A list of OCR results.

        Returns:
            The reply produced by the LLM, as a string.
        """
        filled_prompt = REPLY_OCR_QUESTION_PROMPT.format(
            language=self.language,
            ocr_result=ocr_result,
            query=query,
        )
        return await self._aask(prompt=filled_prompt)

View file

@ -36,6 +36,7 @@ YAPI_URL = "http://yapi.deepwisdomai.com/"
# Well-known directories, all derived from PROJECT_ROOT or DATA_PATH.
TMP = PROJECT_ROOT / "tmp"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
# Directory in which the invoice OCR assistant saves its generated Excel tables.
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"
SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills"

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 16:30:25
@Author : Stitch-z
@File : invoice_ocr.py
@Describe : Prompts of the invoice ocr assistant.
"""
# Opening sentence shared by every invoice OCR prompt.
COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice."

# Prompt asking for the main invoice fields as a JSON dictionary.
# Placeholders: {ocr_result} - the OCR output, {language} - the output language.
# The doubled braces keep the JSON example literal when str.format is applied.
EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """
Please extract the payee, city, total cost, and invoicing date of the invoice.
The OCR data of the invoice are as follows:
{ocr_result}
Mandatory restrictions are returned according to the following requirements:
1. The total cost refers to the total price and tax. Do not include `¥`.
2. The city must be the recipient's city.
3. The returned JSON dictionary must be returned in {language}
4. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}.
"""

# Prompt asking a free-form question about one invoice.
# Placeholders: {query} - the question, {ocr_result} - the OCR output,
# {language} - the language to answer in.
REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """
Please answer the question: {query}
The OCR data of the invoice are as follows:
{ocr_result}
Mandatory restrictions are returned according to the following requirements:
1. Answer in {language} language.
2. Enforce restrictions on not returning OCR data sent to you.
3. Return with markdown syntax layout.
"""

# Message content reported once OCR recognition has finished.
INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice."

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
"""
@Time : 2023/9/21 14:10:05
@Author : Stitch-z
@File : invoice_ocr_assistant.py
"""
import pandas as pd
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
from metagpt.roles import Role
from metagpt.schema import Message
class InvoiceOCRAssistant(Role):
    """Invoice OCR assistant, support OCR text recognition of invoice PDF, png, jpg, and zip files,
    generate a table for the payee, city, total amount, and invoicing date of the invoice,
    and ask questions for a single file based on the OCR recognition results of the invoice.

    Args:
        name: The name of the role.
        profile: The role profile description.
        goal: The goal of the role.
        constraints: Constraints or requirements for the role.
        language: The language in which the invoice table will be generated.
    """

    def __init__(
        self,
        name: str = "Stitch",
        profile: str = "Invoice OCR Assistant",
        goal: str = "OCR identifies invoice files and generates invoice main information table",
        constraints: str = "",
        language: str = "ch",
    ):
        super().__init__(name, profile, goal, constraints)
        self._init_actions([InvoiceOCR])
        self.language = language
        # Name of the uploaded file, reused to name the generated table.
        self.filename = ""
        # Content of the message that triggered OCR; used later as the question.
        self.origin_query = ""
        # OCR result of a single uploaded file (attribute name kept as-is for compatibility).
        self.orc_data = None

    async def _think(self) -> None:
        """Determine the next action to be taken by the role.

        With no pending action the first state is selected; otherwise the role
        advances to the next state, or clears `todo` once all states are done.
        """
        if self._rc.todo is None:
            self._set_state(0)
            return

        if self._rc.state + 1 < len(self._states):
            self._set_state(self._rc.state + 1)
        else:
            self._rc.todo = None

    async def _act(self) -> Message:
        """Perform an action as determined by the role.

        Returns:
            A message containing the result of the action.

        Raises:
            Exception: If the OCR step runs without an uploaded invoice file.
        """
        msg = self._rc.memory.get(k=1)[0]
        todo = self._rc.todo
        if isinstance(todo, InvoiceOCR):
            self.origin_query = msg.content
            # Validate the upload before touching it, so that a missing file
            # yields the intended error instead of an AttributeError.
            file_path = (msg.instruct_content or {}).get("file_path")
            if not file_path:
                raise Exception("Invoice file not uploaded")
            self.filename = file_path.name

            resp = await todo.run(file_path)
            if len(resp) == 1:
                # Single file support for questioning based on OCR recognition results
                self._init_actions([GenerateTable, ReplyQuestion])
                self.orc_data = resp[0]
            else:
                self._init_actions([GenerateTable])

            # The action list was replaced above; clear todo so that the next
            # _think() starts again from the first of the new actions.
            self._rc.todo = None
            content = INVOICE_OCR_SUCCESS
        elif isinstance(todo, GenerateTable):
            # The previous message carries the OCR results as instruct_content.
            ocr_results = msg.instruct_content
            resp = await todo.run(ocr_results, self.filename)

            # Convert list to Markdown format string
            df = pd.DataFrame(resp)
            markdown_table = df.to_markdown(index=False)
            content = f"{markdown_table}\n\n\n"
        else:
            # Any remaining action is run as a question about the single OCR result.
            resp = await todo.run(self.origin_query, self.orc_data)
            content = resp

        msg = Message(content=content, instruct_content=resp)
        self._rc.memory.add(msg)
        return msg

    async def _react(self) -> Message:
        """Execute the invoice ocr assistant's think and actions.

        Returns:
            A message containing the final result of the assistant's actions.
        """
        while True:
            await self._think()
            if self._rc.todo is None:
                break
            msg = await self._act()
        return msg

View file

@ -1,11 +1,11 @@
import re
from typing import List, Callable
from typing import List, Callable, Dict
from pathlib import Path
import wrapt
import textwrap
import inspect
from interpreter.interpreter import Interpreter
from interpreter.core.core import Interpreter
from metagpt.logs import logger
from metagpt.config import CONFIG
@ -41,13 +41,13 @@ class OpenCodeInterpreter(object):
interpreter.auto_run = auto_run
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
interpreter.api_key = CONFIG.openai_api_key
interpreter.api_base = CONFIG.openai_api_base
# interpreter.api_base = CONFIG.openai_api_base
self.interpreter = interpreter
def chat(self, query: str, reset: bool = True):
if reset:
self.interpreter.reset()
return self.interpreter.chat(query, return_messages=True)
return self.interpreter.chat(query)
@staticmethod
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
@ -61,11 +61,30 @@ class OpenCodeInterpreter(object):
assert language == 'python', f"Expect python language for default function_format, but got {language}."
function_format = """def {function_name}():\n{code}"""
# Extract the code module in the open-interpreter respond message.
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
# The query_respond of open-interpreter before v0.1.4 is:
# [{'role': 'user', 'content': your query string},
# {'role': 'assistant', 'content': plan from llm, 'function_call': {
# "name": "run_code", "arguments": "{"language": "python", "code": code of first plan},
# "parsed_arguments": {"language": "python", "code": code of first plan}
# ...]
if "function_call" in query_respond[1]:
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
# The query_respond of open-interpreter v0.1.7 is:
# [{'role': 'user', 'message': your query string},
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
# 'code': code of first plan, 'output': output of first plan code},
# ...]
elif "code" in query_respond[1]:
code = [item['code'] for item in query_respond
if "code" in item
and 'language' in item
and item['language'] == language]
else:
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
# add indent.
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
# Return the code after deduplication.
@ -94,13 +113,49 @@ class OpenInterpreterDecorator(object):
self.code_file_path = code_file_path
self.clear_code = clear_code
def _have_code(self, rsp: List[Dict]):
# Is there any code generated?
return 'code' in rsp[1] and rsp[1]['code'] not in ("", None)
def _is_faild_plan(self, rsp: List[Dict]):
    """Tell whether the plan in the reply counts as failed.

    The reply is assembled into a function named 'function' via
    OpenCodeInterpreter.extract_function; a string result containing at most
    one newline is treated as a failed plan.
    """
    assembled = OpenCodeInterpreter.extract_function(rsp, 'function')
    # NOTE(review): presumably one newline means only the `def` line plus at
    # most a single statement was produced - confirm against extract_function.
    return isinstance(assembled, str) and assembled.count('\n') <= 1
def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3):
    """Resend the query until the reply contains code with a successful plan.

    Args:
        query: The query that was sent to the interpreter.
        interpreter: The interpreter used to resend the query.
        respond: The reply obtained for the first attempt.
        max_try: Maximum number of times the query is resent.

    Returns:
        The first reply that has code and a successful plan.

    Raises:
        ValueError: If, after `max_try` resends, the reply still has no code
            or its plan is still considered failed.
    """
    for _ in range(max_try):
        # If no code or a failed plan is generated, execute chat again,
        # repeating no more than max_try times.
        if not self._have_code(respond):
            logger.warning(f"llm did not return executable code, resend the query: \n{query}")
        elif self._is_faild_plan(respond):
            logger.warning(f"llm did not generate successful plan, resend the query: \n{query}")
        else:
            break
        respond = interpreter.chat(query)

    # Post-processing of respond: the reply of the last resend has not been
    # checked by the loop, so validate once more before handing it back.
    if not self._have_code(respond):
        error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
        logger.error(error_msg)
        raise ValueError(error_msg)
    if self._is_faild_plan(respond):
        # Distinct message so a failed plan is not reported as missing code.
        error_msg = f"OpenCodeInterpreter do not generate successful plan for query: \n{query}"
        logger.error(error_msg)
        raise ValueError(error_msg)
    return respond
def __call__(self, wrapped):
@wrapt.decorator
async def wrapper(wrapped: Callable, instance, args, kwargs):
# Get the decorated function name.
func_name = wrapped.__name__
# If the script exists locally and clearcode is not required, execute the function from the script.
if Path(self.code_file_path).is_file() and not self.clear_code:
if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code:
return run_function_script(self.code_file_path, func_name, *args, **kwargs)
# Auto run generate code by using open-interpreter.
@ -108,6 +163,8 @@ class OpenInterpreterDecorator(object):
query = gen_query(wrapped, args, kwargs)
logger.info(f"query for OpenCodeInterpreter: \n {query}")
respond = interpreter.chat(query)
# Make sure the response is as expected.
respond = self._check_respond(query, interpreter, respond, 3)
# Assemble the code blocks generated by open-interpreter into a function without parameters.
func_code = interpreter.extract_function(respond, func_name)
# Clone the `func_code` into wrapped, that is,
@ -121,9 +178,10 @@ class OpenInterpreterDecorator(object):
# execute this function.
try:
res = run_function_code(code, func_name, *args, **kwargs)
if self.save_code:
if self.save_code and self.code_file_path:
cf._save(self.code_file_path, code)
except Exception as e:
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
raise Exception("Could not evaluate Python code", e)
return res
return wrapper(wrapped)

View file

@ -195,7 +195,8 @@ class OutputParser:
except (ValueError, SyntaxError) as e:
raise Exception(f"Error while extracting and parsing the {data_type}: {e}")
else:
raise Exception(f"No {data_type} found in the text.")
logger.error(f"No {data_type} found in the text.")
return [] if data_type is list else {}
class CodeParser: