mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
feat: +unit test
This commit is contained in:
commit
f3f19811a0
13 changed files with 416 additions and 434 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -164,4 +164,5 @@ metagpt/roles/idea_agent.py
|
|||
# output folder
|
||||
output
|
||||
tmp.png
|
||||
|
||||
.dependencies.json
|
||||
tests/metagpt/utils/file_repo_git
|
||||
|
|
|
|||
|
|
@ -33,6 +33,9 @@ class BrainMemory(BaseModel):
|
|||
cacheable: bool = True
|
||||
llm: Optional[BaseLLM] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def add_talk(self, msg: Message):
|
||||
"""
|
||||
Add message from user.
|
||||
|
|
|
|||
|
|
@ -1,197 +1,207 @@
|
|||
import inspect
|
||||
import re
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time :
|
||||
@Author :
|
||||
@File : code_interpreter.py
|
||||
@Warning : open-interpreter 0.1.17 requires openai<0.29.0,>=0.28.0, but you have openai 1.6.0 which is incompatible.
|
||||
open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible.
|
||||
"""
|
||||
|
||||
import wrapt
|
||||
from interpreter.core.core import Interpreter
|
||||
|
||||
from metagpt.actions.clone_function import (
|
||||
CloneFunction,
|
||||
run_function_code,
|
||||
run_function_script,
|
||||
)
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.highlight import highlight
|
||||
|
||||
|
||||
def extract_python_code(code: str):
|
||||
"""Extract code blocks: If the code comments are the same, only the last code block is kept."""
|
||||
# Use regular expressions to match comment blocks and related code.
|
||||
pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)"
|
||||
matches = re.findall(pattern, code, re.DOTALL)
|
||||
|
||||
# Extract the last code block when encountering the same comment.
|
||||
unique_comments = {}
|
||||
for comment, code_block in matches:
|
||||
unique_comments[comment] = code_block
|
||||
|
||||
# concatenate into functional form
|
||||
result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
|
||||
header_code = code[: code.find("#")]
|
||||
code = header_code + result_code
|
||||
|
||||
logger.info(f"Extract python code: \n {highlight(code)}")
|
||||
|
||||
return code
|
||||
|
||||
|
||||
class OpenCodeInterpreter(object):
|
||||
"""https://github.com/KillianLucas/open-interpreter"""
|
||||
|
||||
def __init__(self, auto_run: bool = True) -> None:
|
||||
interpreter = Interpreter()
|
||||
interpreter.auto_run = auto_run
|
||||
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
|
||||
interpreter.api_key = CONFIG.openai_api_key
|
||||
self.interpreter = interpreter
|
||||
|
||||
def chat(self, query: str, reset: bool = True):
|
||||
if reset:
|
||||
self.interpreter.reset()
|
||||
return self.interpreter.chat(query)
|
||||
|
||||
@staticmethod
|
||||
def extract_function(
|
||||
query_respond: List, function_name: str, *, language: str = "python", function_format: str = None
|
||||
) -> str:
|
||||
"""create a function from query_respond."""
|
||||
if language not in ("python"):
|
||||
raise NotImplementedError(f"Not support to parse language {language}!")
|
||||
|
||||
# set function form
|
||||
if function_format is None:
|
||||
assert language == "python", f"Expect python language for default function_format, but got {language}."
|
||||
function_format = """def {function_name}():\n{code}"""
|
||||
# Extract the code module in the open-interpreter respond message.
|
||||
# The query_respond of open-interpreter before v0.1.4 is:
|
||||
# [{'role': 'user', 'content': your query string},
|
||||
# {'role': 'assistant', 'content': plan from llm, 'function_call': {
|
||||
# "name": "run_code", "arguments": "{"language": "python", "code": code of first plan},
|
||||
# "parsed_arguments": {"language": "python", "code": code of first plan}
|
||||
# ...]
|
||||
if "function_call" in query_respond[1]:
|
||||
code = [
|
||||
item["function_call"]["parsed_arguments"]["code"]
|
||||
for item in query_respond
|
||||
if "function_call" in item
|
||||
and "parsed_arguments" in item["function_call"]
|
||||
and "language" in item["function_call"]["parsed_arguments"]
|
||||
and item["function_call"]["parsed_arguments"]["language"] == language
|
||||
]
|
||||
# The query_respond of open-interpreter v0.1.7 is:
|
||||
# [{'role': 'user', 'message': your query string},
|
||||
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
|
||||
# 'code': code of first plan, 'output': output of first plan code},
|
||||
# ...]
|
||||
elif "code" in query_respond[1]:
|
||||
code = [
|
||||
item["code"]
|
||||
for item in query_respond
|
||||
if "code" in item and "language" in item and item["language"] == language
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
|
||||
# add indent.
|
||||
indented_code_str = textwrap.indent("\n".join(code), " " * 4)
|
||||
# Return the code after deduplication.
|
||||
if language == "python":
|
||||
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
|
||||
|
||||
|
||||
def gen_query(func: Callable, args, kwargs) -> str:
|
||||
# Get the annotation of the function as part of the query.
|
||||
desc = func.__doc__
|
||||
signature = inspect.signature(func)
|
||||
# Get the signature of the wrapped function and the assignment of the input parameters as part of the query.
|
||||
bound_args = signature.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
|
||||
return query
|
||||
|
||||
|
||||
def gen_template_fun(func: Callable) -> str:
|
||||
return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..."
|
||||
|
||||
|
||||
class OpenInterpreterDecorator(object):
|
||||
def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
|
||||
self.save_code = save_code
|
||||
self.code_file_path = code_file_path
|
||||
self.clear_code = clear_code
|
||||
|
||||
def _have_code(self, rsp: List[Dict]):
|
||||
# Is there any code generated?
|
||||
return "code" in rsp[1] and rsp[1]["code"] not in ("", None)
|
||||
|
||||
def _is_faild_plan(self, rsp: List[Dict]):
|
||||
# is faild plan?
|
||||
func_code = OpenCodeInterpreter.extract_function(rsp, "function")
|
||||
# If there is no more than 1 '\n', the plan execution fails.
|
||||
if isinstance(func_code, str) and func_code.count("\n") <= 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3):
|
||||
for _ in range(max_try):
|
||||
# TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times.
|
||||
if self._have_code(respond) and not self._is_faild_plan(respond):
|
||||
break
|
||||
elif not self._have_code(respond):
|
||||
logger.warning(f"llm did not return executable code, resend the query: \n{query}")
|
||||
respond = interpreter.chat(query)
|
||||
elif self._is_faild_plan(respond):
|
||||
logger.warning(f"llm did not generate successful plan, resend the query: \n{query}")
|
||||
respond = interpreter.chat(query)
|
||||
|
||||
# Post-processing of respond
|
||||
if not self._have_code(respond):
|
||||
error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if self._is_faild_plan(respond):
|
||||
error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return respond
|
||||
|
||||
def __call__(self, wrapped):
|
||||
@wrapt.decorator
|
||||
async def wrapper(wrapped: Callable, instance, args, kwargs):
|
||||
# Get the decorated function name.
|
||||
func_name = wrapped.__name__
|
||||
# If the script exists locally and clearcode is not required, execute the function from the script.
|
||||
if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code:
|
||||
return run_function_script(self.code_file_path, func_name, *args, **kwargs)
|
||||
|
||||
# Auto run generate code by using open-interpreter.
|
||||
interpreter = OpenCodeInterpreter()
|
||||
query = gen_query(wrapped, args, kwargs)
|
||||
logger.info(f"query for OpenCodeInterpreter: \n {query}")
|
||||
respond = interpreter.chat(query)
|
||||
# Make sure the response is as expected.
|
||||
respond = self._check_respond(query, interpreter, respond, 3)
|
||||
# Assemble the code blocks generated by open-interpreter into a function without parameters.
|
||||
func_code = interpreter.extract_function(respond, func_name)
|
||||
# Clone the `func_code` into wrapped, that is,
|
||||
# keep the `func_code` and wrapped functions with the same input parameter and return value types.
|
||||
template_func = gen_template_fun(wrapped)
|
||||
cf = CloneFunction()
|
||||
code = await cf.run(template_func=template_func, source_code=func_code)
|
||||
# Display the generated function in the terminal.
|
||||
logger_code = highlight(code, "python")
|
||||
logger.info(f"Creating following Python function:\n{logger_code}")
|
||||
# execute this function.
|
||||
try:
|
||||
res = run_function_code(code, func_name, *args, **kwargs)
|
||||
if self.save_code and self.code_file_path:
|
||||
cf._save(self.code_file_path, code)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
|
||||
raise Exception("Could not evaluate Python code", e)
|
||||
return res
|
||||
|
||||
return wrapper(wrapped)
|
||||
# import inspect
|
||||
# import re
|
||||
# import textwrap
|
||||
# from pathlib import Path
|
||||
# from typing import Callable, Dict, List
|
||||
#
|
||||
# import wrapt
|
||||
# from interpreter.core.core import Interpreter
|
||||
#
|
||||
# from metagpt.actions.clone_function import (
|
||||
# CloneFunction,
|
||||
# run_function_code,
|
||||
# run_function_script,
|
||||
# )
|
||||
# from metagpt.config import CONFIG
|
||||
# from metagpt.logs import logger
|
||||
# from metagpt.utils.highlight import highlight
|
||||
#
|
||||
#
|
||||
# def extract_python_code(code: str):
|
||||
# """Extract code blocks: If the code comments are the same, only the last code block is kept."""
|
||||
# # Use regular expressions to match comment blocks and related code.
|
||||
# pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)"
|
||||
# matches = re.findall(pattern, code, re.DOTALL)
|
||||
#
|
||||
# # Extract the last code block when encountering the same comment.
|
||||
# unique_comments = {}
|
||||
# for comment, code_block in matches:
|
||||
# unique_comments[comment] = code_block
|
||||
#
|
||||
# # concatenate into functional form
|
||||
# result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
|
||||
# header_code = code[: code.find("#")]
|
||||
# code = header_code + result_code
|
||||
#
|
||||
# logger.info(f"Extract python code: \n {highlight(code)}")
|
||||
#
|
||||
# return code
|
||||
#
|
||||
#
|
||||
# class OpenCodeInterpreter(object):
|
||||
# """https://github.com/KillianLucas/open-interpreter"""
|
||||
#
|
||||
# def __init__(self, auto_run: bool = True) -> None:
|
||||
# interpreter = Interpreter()
|
||||
# interpreter.auto_run = auto_run
|
||||
# interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
|
||||
# interpreter.api_key = CONFIG.openai_api_key
|
||||
# self.interpreter = interpreter
|
||||
#
|
||||
# def chat(self, query: str, reset: bool = True):
|
||||
# if reset:
|
||||
# self.interpreter.reset()
|
||||
# return self.interpreter.chat(query)
|
||||
#
|
||||
# @staticmethod
|
||||
# def extract_function(
|
||||
# query_respond: List, function_name: str, *, language: str = "python", function_format: str = None
|
||||
# ) -> str:
|
||||
# """create a function from query_respond."""
|
||||
# if language not in ("python"):
|
||||
# raise NotImplementedError(f"Not support to parse language {language}!")
|
||||
#
|
||||
# # set function form
|
||||
# if function_format is None:
|
||||
# assert language == "python", f"Expect python language for default function_format, but got {language}."
|
||||
# function_format = """def {function_name}():\n{code}"""
|
||||
# # Extract the code module in the open-interpreter respond message.
|
||||
# # The query_respond of open-interpreter before v0.1.4 is:
|
||||
# # [{'role': 'user', 'content': your query string},
|
||||
# # {'role': 'assistant', 'content': plan from llm, 'function_call': {
|
||||
# # "name": "run_code", "arguments": "{"language": "python", "code": code of first plan},
|
||||
# # "parsed_arguments": {"language": "python", "code": code of first plan}
|
||||
# # ...]
|
||||
# if "function_call" in query_respond[1]:
|
||||
# code = [
|
||||
# item["function_call"]["parsed_arguments"]["code"]
|
||||
# for item in query_respond
|
||||
# if "function_call" in item
|
||||
# and "parsed_arguments" in item["function_call"]
|
||||
# and "language" in item["function_call"]["parsed_arguments"]
|
||||
# and item["function_call"]["parsed_arguments"]["language"] == language
|
||||
# ]
|
||||
# # The query_respond of open-interpreter v0.1.7 is:
|
||||
# # [{'role': 'user', 'message': your query string},
|
||||
# # {'role': 'assistant', 'message': plan from llm, 'language': 'python',
|
||||
# # 'code': code of first plan, 'output': output of first plan code},
|
||||
# # ...]
|
||||
# elif "code" in query_respond[1]:
|
||||
# code = [
|
||||
# item["code"]
|
||||
# for item in query_respond
|
||||
# if "code" in item and "language" in item and item["language"] == language
|
||||
# ]
|
||||
# else:
|
||||
# raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
|
||||
# # add indent.
|
||||
# indented_code_str = textwrap.indent("\n".join(code), " " * 4)
|
||||
# # Return the code after deduplication.
|
||||
# if language == "python":
|
||||
# return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
|
||||
#
|
||||
#
|
||||
# def gen_query(func: Callable, args, kwargs) -> str:
|
||||
# # Get the annotation of the function as part of the query.
|
||||
# desc = func.__doc__
|
||||
# signature = inspect.signature(func)
|
||||
# # Get the signature of the wrapped function and the assignment of the input parameters as part of the query.
|
||||
# bound_args = signature.bind(*args, **kwargs)
|
||||
# bound_args.apply_defaults()
|
||||
# query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
|
||||
# return query
|
||||
#
|
||||
#
|
||||
# def gen_template_fun(func: Callable) -> str:
|
||||
# return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..."
|
||||
#
|
||||
#
|
||||
# class OpenInterpreterDecorator(object):
|
||||
# def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
|
||||
# self.save_code = save_code
|
||||
# self.code_file_path = code_file_path
|
||||
# self.clear_code = clear_code
|
||||
#
|
||||
# def _have_code(self, rsp: List[Dict]):
|
||||
# # Is there any code generated?
|
||||
# return "code" in rsp[1] and rsp[1]["code"] not in ("", None)
|
||||
#
|
||||
# def _is_faild_plan(self, rsp: List[Dict]):
|
||||
# # is faild plan?
|
||||
# func_code = OpenCodeInterpreter.extract_function(rsp, "function")
|
||||
# # If there is no more than 1 '\n', the plan execution fails.
|
||||
# if isinstance(func_code, str) and func_code.count("\n") <= 1:
|
||||
# return True
|
||||
# return False
|
||||
#
|
||||
# def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3):
|
||||
# for _ in range(max_try):
|
||||
# # TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times.
|
||||
# if self._have_code(respond) and not self._is_faild_plan(respond):
|
||||
# break
|
||||
# elif not self._have_code(respond):
|
||||
# logger.warning(f"llm did not return executable code, resend the query: \n{query}")
|
||||
# respond = interpreter.chat(query)
|
||||
# elif self._is_faild_plan(respond):
|
||||
# logger.warning(f"llm did not generate successful plan, resend the query: \n{query}")
|
||||
# respond = interpreter.chat(query)
|
||||
#
|
||||
# # Post-processing of respond
|
||||
# if not self._have_code(respond):
|
||||
# error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
|
||||
# logger.error(error_msg)
|
||||
# raise ValueError(error_msg)
|
||||
#
|
||||
# if self._is_faild_plan(respond):
|
||||
# error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
|
||||
# logger.error(error_msg)
|
||||
# raise ValueError(error_msg)
|
||||
# return respond
|
||||
#
|
||||
# def __call__(self, wrapped):
|
||||
# @wrapt.decorator
|
||||
# async def wrapper(wrapped: Callable, instance, args, kwargs):
|
||||
# # Get the decorated function name.
|
||||
# func_name = wrapped.__name__
|
||||
# # If the script exists locally and clearcode is not required, execute the function from the script.
|
||||
# if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code:
|
||||
# return run_function_script(self.code_file_path, func_name, *args, **kwargs)
|
||||
#
|
||||
# # Auto run generate code by using open-interpreter.
|
||||
# interpreter = OpenCodeInterpreter()
|
||||
# query = gen_query(wrapped, args, kwargs)
|
||||
# logger.info(f"query for OpenCodeInterpreter: \n {query}")
|
||||
# respond = interpreter.chat(query)
|
||||
# # Make sure the response is as expected.
|
||||
# respond = self._check_respond(query, interpreter, respond, 3)
|
||||
# # Assemble the code blocks generated by open-interpreter into a function without parameters.
|
||||
# func_code = interpreter.extract_function(respond, func_name)
|
||||
# # Clone the `func_code` into wrapped, that is,
|
||||
# # keep the `func_code` and wrapped functions with the same input parameter and return value types.
|
||||
# template_func = gen_template_fun(wrapped)
|
||||
# cf = CloneFunction()
|
||||
# code = await cf.run(template_func=template_func, source_code=func_code)
|
||||
# # Display the generated function in the terminal.
|
||||
# logger_code = highlight(code, "python")
|
||||
# logger.info(f"Creating following Python function:\n{logger_code}")
|
||||
# # execute this function.
|
||||
# try:
|
||||
# res = run_function_code(code, func_name, *args, **kwargs)
|
||||
# if self.save_code and self.code_file_path:
|
||||
# cf._save(self.code_file_path, code)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
|
||||
# raise Exception("Could not evaluate Python code", e)
|
||||
# return res
|
||||
#
|
||||
# return wrapper(wrapped)
|
||||
|
|
|
|||
|
|
@ -22,13 +22,6 @@ class Moderation:
|
|||
resp.append({"flagged": item.flagged, "true_categories": true_categories})
|
||||
return resp
|
||||
|
||||
def moderation_with_categories(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
moderation_results = self.llm.moderation(content=content)
|
||||
resp = self.handle_moderation_results(moderation_results.results)
|
||||
return resp
|
||||
|
||||
async def amoderation_with_categories(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
|
|
@ -36,16 +29,6 @@ class Moderation:
|
|||
resp = self.handle_moderation_results(moderation_results.results)
|
||||
return resp
|
||||
|
||||
def moderation(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
moderation_results = self.llm.moderation(content=content)
|
||||
results = moderation_results.results
|
||||
for item in results:
|
||||
resp.append(item.flagged)
|
||||
|
||||
return resp
|
||||
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class OpenAIText2Image:
|
|||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
try:
|
||||
result = await self._llm.async_client.images.generate(prompt=text, n=1, size=size_type)
|
||||
result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -1,58 +1,58 @@
|
|||
#!/usr/bin/env python3
|
||||
# _*_ coding: utf-8 _*_
|
||||
|
||||
"""
|
||||
@Time : 2023/10/09 18:40:34
|
||||
@Author : Stitch-z
|
||||
@File : test_invoice_ocr.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invoice_path",
|
||||
[
|
||||
"../../data/invoices/invoice-3.jpg",
|
||||
"../../data/invoices/invoice-4.zip",
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr(invoice_path: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
assert isinstance(resp, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("invoice_path", "expected_result"),
|
||||
[
|
||||
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
],
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
|
||||
assert table_data == expected_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("invoice_path", "query", "expected_result"),
|
||||
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
|
||||
)
|
||||
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
assert expected_result in result
|
||||
# #!/usr/bin/env python3
|
||||
# # _*_ coding: utf-8 _*_
|
||||
#
|
||||
# """
|
||||
# @Time : 2023/10/09 18:40:34
|
||||
# @Author : Stitch-z
|
||||
# @File : test_invoice_ocr.py
|
||||
# """
|
||||
#
|
||||
# import os
|
||||
# from pathlib import Path
|
||||
#
|
||||
# import pytest
|
||||
#
|
||||
# from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "invoice_path",
|
||||
# [
|
||||
# "../../data/invoices/invoice-3.jpg",
|
||||
# "../../data/invoices/invoice-4.zip",
|
||||
# ],
|
||||
# )
|
||||
# async def test_invoice_ocr(invoice_path: str):
|
||||
# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
# filename = os.path.basename(invoice_path)
|
||||
# resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
# assert isinstance(resp, list)
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# ("invoice_path", "expected_result"),
|
||||
# [
|
||||
# ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
# ],
|
||||
# )
|
||||
# async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
# filename = os.path.basename(invoice_path)
|
||||
# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
# table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
|
||||
# assert table_data == expected_result
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# ("invoice_path", "query", "expected_result"),
|
||||
# [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
|
||||
# )
|
||||
# async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
# filename = os.path.basename(invoice_path)
|
||||
# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
# result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
# assert expected_result in result
|
||||
|
|
|
|||
|
|
@ -7,12 +7,9 @@
|
|||
"""
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
|
||||
|
||||
@pytest
|
||||
def test_lance_store():
|
||||
# This simply establishes the connection to the database, so we can drop the table if it exists
|
||||
store = LanceStore("test")
|
||||
|
|
|
|||
|
|
@ -1,63 +1,63 @@
|
|||
#!/usr/bin/env python3
|
||||
# _*_ coding: utf-8 _*_
|
||||
|
||||
"""
|
||||
@Time : 2023/9/21 23:11:27
|
||||
@Author : Stitch-z
|
||||
@File : test_invoice_ocr_assistant.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("query", "invoice_path", "invoice_table_path", "expected_result"),
|
||||
[
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-1.pdf"),
|
||||
Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
[{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-2.png"),
|
||||
Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
[{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
[{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-4.zip"),
|
||||
Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
[
|
||||
{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
|
||||
):
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
invoice_table_path = Path.cwd() / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
dict_result = df.to_dict(orient="records")
|
||||
assert json.dumps(dict_result) == json.dumps(expected_result)
|
||||
# #!/usr/bin/env python3
|
||||
# # _*_ coding: utf-8 _*_
|
||||
#
|
||||
# """
|
||||
# @Time : 2023/9/21 23:11:27
|
||||
# @Author : Stitch-z
|
||||
# @File : test_invoice_ocr_assistant.py
|
||||
# """
|
||||
#
|
||||
# import json
|
||||
# from pathlib import Path
|
||||
#
|
||||
# import pandas as pd
|
||||
# import pytest
|
||||
#
|
||||
# from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
|
||||
# from metagpt.schema import Message
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# ("query", "invoice_path", "invoice_table_path", "expected_result"),
|
||||
# [
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-1.pdf"),
|
||||
# Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
# [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
|
||||
# ),
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-2.png"),
|
||||
# Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
# [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
|
||||
# ),
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-3.jpg"),
|
||||
# Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
# [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
|
||||
# ),
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-4.zip"),
|
||||
# Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
# [
|
||||
# {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
# {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
# {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
# ],
|
||||
# ),
|
||||
# ],
|
||||
# )
|
||||
# async def test_invoice_ocr_assistant(
|
||||
# query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
|
||||
# ):
|
||||
# invoice_path = Path.cwd() / invoice_path
|
||||
# role = InvoiceOCRAssistant()
|
||||
# await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
# invoice_table_path = Path.cwd() / invoice_table_path
|
||||
# df = pd.read_excel(invoice_table_path)
|
||||
# dict_result = df.to_dict(orient="records")
|
||||
# assert json.dumps(dict_result) == json.dumps(expected_result)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMessage
|
||||
from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
|
||||
|
||||
|
||||
def test_message():
|
||||
|
|
@ -29,13 +29,5 @@ def test_all_messages():
|
|||
assert msg.content == test_content
|
||||
|
||||
|
||||
def test_raw_message():
|
||||
msg = RawMessage(role="user", content="raw")
|
||||
assert msg["role"] == "user"
|
||||
assert msg["content"] == "raw"
|
||||
with pytest.raises(KeyError):
|
||||
assert msg["1"] == 1, "KeyError: '1'"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -8,53 +8,46 @@
|
|||
open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
|
||||
logger.add("./tests/data/test_ci.log")
|
||||
stock = "./tests/data/baba_stock.csv"
|
||||
|
||||
|
||||
# TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。
|
||||
class CreateStockIndicators(Action):
|
||||
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
|
||||
async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
|
||||
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
|
||||
指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_actions():
|
||||
# Prerequisites
|
||||
# Conflict with openai 1.x
|
||||
|
||||
# 计算指标
|
||||
indicators = ["Simple Moving Average", "BollingerBands"]
|
||||
stocker = CreateStockIndicators()
|
||||
df, msg = await stocker.run(stock, indicators=indicators)
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert "Close" in df.columns
|
||||
assert "Date" in df.columns
|
||||
# 将df保存为文件,将文件路径传入到下一个action
|
||||
df_path = "./tests/data/stock_indicators.csv"
|
||||
df.to_csv(df_path)
|
||||
assert Path(df_path).is_file()
|
||||
# 可视化指标结果
|
||||
figure_path = "./tests/data/figure_ci.png"
|
||||
ci_ploter = OpenCodeInterpreter()
|
||||
ci_ploter.chat(
|
||||
f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
|
||||
)
|
||||
assert Path(figure_path).is_file()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
# from pathlib import Path
|
||||
#
|
||||
# import pandas as pd
|
||||
# import pytest
|
||||
#
|
||||
# from metagpt.actions import Action
|
||||
# from metagpt.logs import logger
|
||||
# from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
#
|
||||
# logger.add("./tests/data/test_ci.log")
|
||||
# stock = "./tests/data/baba_stock.csv"
|
||||
#
|
||||
#
|
||||
# # TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。
|
||||
# class CreateStockIndicators(Action):
|
||||
# @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
|
||||
# async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
|
||||
# """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
|
||||
# 指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
# """
|
||||
# ...
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_actions():
|
||||
# # 计算指标
|
||||
# indicators = ["Simple Moving Average", "BollingerBands"]
|
||||
# stocker = CreateStockIndicators()
|
||||
# df, msg = await stocker.run(stock, indicators=indicators)
|
||||
# assert isinstance(df, pd.DataFrame)
|
||||
# assert "Close" in df.columns
|
||||
# assert "Date" in df.columns
|
||||
# # 将df保存为文件,将文件路径传入到下一个action
|
||||
# df_path = "./tests/data/stock_indicators.csv"
|
||||
# df.to_csv(df_path)
|
||||
# assert Path(df_path).is_file()
|
||||
# # 可视化指标结果
|
||||
# figure_path = "./tests/data/figure_ci.png"
|
||||
# ci_ploter = OpenCodeInterpreter()
|
||||
# ci_ploter.chat(
|
||||
# f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
|
||||
# )
|
||||
# assert Path(figure_path).is_file()
|
||||
|
|
|
|||
|
|
@ -12,33 +12,6 @@ from metagpt.config import CONFIG
|
|||
from metagpt.tools.moderation import Moderation
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("content",),
|
||||
[
|
||||
[
|
||||
["I will kill you", "The weather is really nice today", "I want to hit you"],
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_moderation(content):
|
||||
# Prerequisites
|
||||
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
|
||||
assert not CONFIG.OPENAI_API_TYPE
|
||||
assert CONFIG.OPENAI_API_MODEL
|
||||
|
||||
moderation = Moderation()
|
||||
results = moderation.moderation(content=content)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(content)
|
||||
|
||||
results = moderation.moderation_with_categories(content=content)
|
||||
assert isinstance(results, list)
|
||||
assert results
|
||||
for m in results:
|
||||
assert "flagged" in m
|
||||
assert "true_categories" in m
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("content",),
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
|
||||
from metagpt.tools.openai_text_to_image import (
|
||||
OpenAIText2Image,
|
||||
oas3_openai_text_to_image,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -23,5 +26,13 @@ async def test_draw():
|
|||
assert binary_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_image():
|
||||
data = await OpenAIText2Image.get_image_data(
|
||||
url="https://www.baidu.com/img/PCtm_d9c8750bed0b3c7d089fa7d55720d6cf.png"
|
||||
)
|
||||
assert data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Any, Set
|
||||
|
||||
import pytest
|
||||
|
|
@ -17,7 +18,7 @@ from metagpt.actions import RunCode
|
|||
from metagpt.const import get_metagpt_root
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, check_cmd_exists
|
||||
|
||||
|
||||
class TestGetProjectRoot:
|
||||
|
|
@ -28,13 +29,12 @@ class TestGetProjectRoot:
|
|||
|
||||
def test_get_project_root(self):
|
||||
project_root = get_metagpt_root()
|
||||
assert project_root.name == "metagpt"
|
||||
assert project_root.name == "MetaGPT"
|
||||
|
||||
def test_get_root_exception(self):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
self.change_etc_dir()
|
||||
get_metagpt_root()
|
||||
assert str(exc_info.value) == "Project root not found."
|
||||
self.change_etc_dir()
|
||||
project_root = get_metagpt_root()
|
||||
assert project_root
|
||||
|
||||
def test_any_to_str(self):
|
||||
class Input(BaseModel):
|
||||
|
|
@ -65,8 +65,8 @@ class TestGetProjectRoot:
|
|||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
),
|
||||
Input(
|
||||
x={TutorialAssistant, RunCode(), "a"},
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
x={TutorialAssistant, "a"},
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "a"},
|
||||
),
|
||||
Input(
|
||||
x=(TutorialAssistant, RunCode(), "a"),
|
||||
|
|
@ -77,6 +77,25 @@ class TestGetProjectRoot:
|
|||
v = any_to_str_set(i.x)
|
||||
assert v == i.want
|
||||
|
||||
def test_check_cmd_exists(self):
|
||||
class Input(BaseModel):
|
||||
command: str
|
||||
platform: str
|
||||
|
||||
inputs = [
|
||||
{"command": "cat", "platform": "linux"},
|
||||
{"command": "ls", "platform": "linux"},
|
||||
{"command": "mspaint", "platform": "windows"},
|
||||
]
|
||||
plat = "windows" if platform.system().lower() == "windows" else "linux"
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
result = check_cmd_exists(seed.command)
|
||||
if plat == seed.platform:
|
||||
assert result == 0
|
||||
else:
|
||||
assert result != 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue