diff --git a/.gitignore b/.gitignore index 93e24ba48..039ba1956 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,5 @@ metagpt/roles/idea_agent.py # output folder output tmp.png - +.dependencies.json +tests/metagpt/utils/file_repo_git diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 0833d71a1..c882859d8 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -33,6 +33,9 @@ class BrainMemory(BaseModel): cacheable: bool = True llm: Optional[BaseLLM] = None + class Config: + arbitrary_types_allowed = True + def add_talk(self, msg: Message): """ Add message from user. diff --git a/metagpt/tools/code_interpreter.py b/metagpt/tools/code_interpreter.py index 9575d6c13..5592b0704 100644 --- a/metagpt/tools/code_interpreter.py +++ b/metagpt/tools/code_interpreter.py @@ -1,197 +1,207 @@ -import inspect -import re -import textwrap -from pathlib import Path -from typing import Callable, Dict, List +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : +@Author : +@File : code_interpreter.py +@Warning : open-interpreter 0.1.17 requires openai<0.29.0,>=0.28.0, but you have openai 1.6.0 which is incompatible. + open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible. +""" -import wrapt -from interpreter.core.core import Interpreter - -from metagpt.actions.clone_function import ( - CloneFunction, - run_function_code, - run_function_script, -) -from metagpt.config import CONFIG -from metagpt.logs import logger -from metagpt.utils.highlight import highlight - - -def extract_python_code(code: str): - """Extract code blocks: If the code comments are the same, only the last code block is kept.""" - # Use regular expressions to match comment blocks and related code. - pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)" - matches = re.findall(pattern, code, re.DOTALL) - - # Extract the last code block when encountering the same comment. - unique_comments = {} - for comment, code_block in matches: - unique_comments[comment] = code_block - - # concatenate into functional form - result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) - header_code = code[: code.find("#")] - code = header_code + result_code - - logger.info(f"Extract python code: \n {highlight(code)}") - - return code - - -class OpenCodeInterpreter(object): - """https://github.com/KillianLucas/open-interpreter""" - - def __init__(self, auto_run: bool = True) -> None: - interpreter = Interpreter() - interpreter.auto_run = auto_run - interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo" - interpreter.api_key = CONFIG.openai_api_key - self.interpreter = interpreter - - def chat(self, query: str, reset: bool = True): - if reset: - self.interpreter.reset() - return self.interpreter.chat(query) - - @staticmethod - def extract_function( - query_respond: List, function_name: str, *, language: str = "python", function_format: str = None - ) -> str: - """create a function from query_respond.""" - if language not in ("python"): - raise NotImplementedError(f"Not support to parse language {language}!") - - # set function form - if function_format is None: - assert language == "python", f"Expect python language for default function_format, but got {language}." - function_format = """def {function_name}():\n{code}""" - # Extract the code module in the open-interpreter respond message. - # The query_respond of open-interpreter before v0.1.4 is: - # [{'role': 'user', 'content': your query string}, - # {'role': 'assistant', 'content': plan from llm, 'function_call': { - # "name": "run_code", "arguments": "{"language": "python", "code": code of first plan}, - # "parsed_arguments": {"language": "python", "code": code of first plan} - # ...] - if "function_call" in query_respond[1]: - code = [ - item["function_call"]["parsed_arguments"]["code"] - for item in query_respond - if "function_call" in item - and "parsed_arguments" in item["function_call"] - and "language" in item["function_call"]["parsed_arguments"] - and item["function_call"]["parsed_arguments"]["language"] == language - ] - # The query_respond of open-interpreter v0.1.7 is: - # [{'role': 'user', 'message': your query string}, - # {'role': 'assistant', 'message': plan from llm, 'language': 'python', - # 'code': code of first plan, 'output': output of first plan code}, - # ...] - elif "code" in query_respond[1]: - code = [ - item["code"] - for item in query_respond - if "code" in item and "language" in item and item["language"] == language - ] - else: - raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}") - # add indent. - indented_code_str = textwrap.indent("\n".join(code), " " * 4) - # Return the code after deduplication. - if language == "python": - return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str)) - - -def gen_query(func: Callable, args, kwargs) -> str: - # Get the annotation of the function as part of the query. - desc = func.__doc__ - signature = inspect.signature(func) - # Get the signature of the wrapped function and the assignment of the input parameters as part of the query. - bound_args = signature.bind(*args, **kwargs) - bound_args.apply_defaults() - query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..." - return query - - -def gen_template_fun(func: Callable) -> str: - return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..." - - -class OpenInterpreterDecorator(object): - def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None: - self.save_code = save_code - self.code_file_path = code_file_path - self.clear_code = clear_code - - def _have_code(self, rsp: List[Dict]): - # Is there any code generated? - return "code" in rsp[1] and rsp[1]["code"] not in ("", None) - - def _is_faild_plan(self, rsp: List[Dict]): - # is faild plan? - func_code = OpenCodeInterpreter.extract_function(rsp, "function") - # If there is no more than 1 '\n', the plan execution fails. - if isinstance(func_code, str) and func_code.count("\n") <= 1: - return True - return False - - def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3): - for _ in range(max_try): - # TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times. - if self._have_code(respond) and not self._is_faild_plan(respond): - break - elif not self._have_code(respond): - logger.warning(f"llm did not return executable code, resend the query: \n{query}") - respond = interpreter.chat(query) - elif self._is_faild_plan(respond): - logger.warning(f"llm did not generate successful plan, resend the query: \n{query}") - respond = interpreter.chat(query) - - # Post-processing of respond - if not self._have_code(respond): - error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" - logger.error(error_msg) - raise ValueError(error_msg) - - if self._is_faild_plan(respond): - error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" - logger.error(error_msg) - raise ValueError(error_msg) - return respond - - def __call__(self, wrapped): - @wrapt.decorator - async def wrapper(wrapped: Callable, instance, args, kwargs): - # Get the decorated function name. - func_name = wrapped.__name__ - # If the script exists locally and clearcode is not required, execute the function from the script. - if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code: - return run_function_script(self.code_file_path, func_name, *args, **kwargs) - - # Auto run generate code by using open-interpreter. - interpreter = OpenCodeInterpreter() - query = gen_query(wrapped, args, kwargs) - logger.info(f"query for OpenCodeInterpreter: \n {query}") - respond = interpreter.chat(query) - # Make sure the response is as expected. - respond = self._check_respond(query, interpreter, respond, 3) - # Assemble the code blocks generated by open-interpreter into a function without parameters. - func_code = interpreter.extract_function(respond, func_name) - # Clone the `func_code` into wrapped, that is, - # keep the `func_code` and wrapped functions with the same input parameter and return value types. - template_func = gen_template_fun(wrapped) - cf = CloneFunction() - code = await cf.run(template_func=template_func, source_code=func_code) - # Display the generated function in the terminal. - logger_code = highlight(code, "python") - logger.info(f"Creating following Python function:\n{logger_code}") - # execute this function. - try: - res = run_function_code(code, func_name, *args, **kwargs) - if self.save_code and self.code_file_path: - cf._save(self.code_file_path, code) - except Exception as e: - logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}") - raise Exception("Could not evaluate Python code", e) - return res - - return wrapper(wrapped) +# import inspect +# import re +# import textwrap +# from pathlib import Path +# from typing import Callable, Dict, List +# +# import wrapt +# from interpreter.core.core import Interpreter +# +# from metagpt.actions.clone_function import ( +# CloneFunction, +# run_function_code, +# run_function_script, +# ) +# from metagpt.config import CONFIG +# from metagpt.logs import logger +# from metagpt.utils.highlight import highlight +# +# +# def extract_python_code(code: str): +# """Extract code blocks: If the code comments are the same, only the last code block is kept.""" +# # Use regular expressions to match comment blocks and related code. +# pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)" +# matches = re.findall(pattern, code, re.DOTALL) +# +# # Extract the last code block when encountering the same comment. +# unique_comments = {} +# for comment, code_block in matches: +# unique_comments[comment] = code_block +# +# # concatenate into functional form +# result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) +# header_code = code[: code.find("#")] +# code = header_code + result_code +# +# logger.info(f"Extract python code: \n {highlight(code)}") +# +# return code +# +# +# class OpenCodeInterpreter(object): +# """https://github.com/KillianLucas/open-interpreter""" +# +# def __init__(self, auto_run: bool = True) -> None: +# interpreter = Interpreter() +# interpreter.auto_run = auto_run +# interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo" +# interpreter.api_key = CONFIG.openai_api_key +# self.interpreter = interpreter +# +# def chat(self, query: str, reset: bool = True): +# if reset: +# self.interpreter.reset() +# return self.interpreter.chat(query) +# +# @staticmethod +# def extract_function( +# query_respond: List, function_name: str, *, language: str = "python", function_format: str = None +# ) -> str: +# """create a function from query_respond.""" +# if language not in ("python"): +# raise NotImplementedError(f"Not support to parse language {language}!") +# +# # set function form +# if function_format is None: +# assert language == "python", f"Expect python language for default function_format, but got {language}." +# function_format = """def {function_name}():\n{code}""" +# # Extract the code module in the open-interpreter respond message. +# # The query_respond of open-interpreter before v0.1.4 is: +# # [{'role': 'user', 'content': your query string}, +# # {'role': 'assistant', 'content': plan from llm, 'function_call': { +# # "name": "run_code", "arguments": "{"language": "python", "code": code of first plan}, +# # "parsed_arguments": {"language": "python", "code": code of first plan} +# # ...] +# if "function_call" in query_respond[1]: +# code = [ +# item["function_call"]["parsed_arguments"]["code"] +# for item in query_respond +# if "function_call" in item +# and "parsed_arguments" in item["function_call"] +# and "language" in item["function_call"]["parsed_arguments"] +# and item["function_call"]["parsed_arguments"]["language"] == language +# ] +# # The query_respond of open-interpreter v0.1.7 is: +# # [{'role': 'user', 'message': your query string}, +# # {'role': 'assistant', 'message': plan from llm, 'language': 'python', +# # 'code': code of first plan, 'output': output of first plan code}, +# # ...] +# elif "code" in query_respond[1]: +# code = [ +# item["code"] +# for item in query_respond +# if "code" in item and "language" in item and item["language"] == language +# ] +# else: +# raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}") +# # add indent. +# indented_code_str = textwrap.indent("\n".join(code), " " * 4) +# # Return the code after deduplication. +# if language == "python": +# return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str)) +# +# +# def gen_query(func: Callable, args, kwargs) -> str: +# # Get the annotation of the function as part of the query. +# desc = func.__doc__ +# signature = inspect.signature(func) +# # Get the signature of the wrapped function and the assignment of the input parameters as part of the query. +# bound_args = signature.bind(*args, **kwargs) +# bound_args.apply_defaults() +# query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..." +# return query +# +# +# def gen_template_fun(func: Callable) -> str: +# return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..." +# +# +# class OpenInterpreterDecorator(object): +# def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None: +# self.save_code = save_code +# self.code_file_path = code_file_path +# self.clear_code = clear_code +# +# def _have_code(self, rsp: List[Dict]): +# # Is there any code generated? +# return "code" in rsp[1] and rsp[1]["code"] not in ("", None) +# +# def _is_faild_plan(self, rsp: List[Dict]): +# # is faild plan? +# func_code = OpenCodeInterpreter.extract_function(rsp, "function") +# # If there is no more than 1 '\n', the plan execution fails. +# if isinstance(func_code, str) and func_code.count("\n") <= 1: +# return True +# return False +# +# def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3): +# for _ in range(max_try): +# # TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times. +# if self._have_code(respond) and not self._is_faild_plan(respond): +# break +# elif not self._have_code(respond): +# logger.warning(f"llm did not return executable code, resend the query: \n{query}") +# respond = interpreter.chat(query) +# elif self._is_faild_plan(respond): +# logger.warning(f"llm did not generate successful plan, resend the query: \n{query}") +# respond = interpreter.chat(query) +# +# # Post-processing of respond +# if not self._have_code(respond): +# error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" +# logger.error(error_msg) +# raise ValueError(error_msg) +# +# if self._is_faild_plan(respond): +# error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" +# logger.error(error_msg) +# raise ValueError(error_msg) +# return respond +# +# def __call__(self, wrapped): +# @wrapt.decorator +# async def wrapper(wrapped: Callable, instance, args, kwargs): +# # Get the decorated function name. +# func_name = wrapped.__name__ +# # If the script exists locally and clearcode is not required, execute the function from the script. +# if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code: +# return run_function_script(self.code_file_path, func_name, *args, **kwargs) +# +# # Auto run generate code by using open-interpreter. +# interpreter = OpenCodeInterpreter() +# query = gen_query(wrapped, args, kwargs) +# logger.info(f"query for OpenCodeInterpreter: \n {query}") +# respond = interpreter.chat(query) +# # Make sure the response is as expected. +# respond = self._check_respond(query, interpreter, respond, 3) +# # Assemble the code blocks generated by open-interpreter into a function without parameters. +# func_code = interpreter.extract_function(respond, func_name) +# # Clone the `func_code` into wrapped, that is, +# # keep the `func_code` and wrapped functions with the same input parameter and return value types. +# template_func = gen_template_fun(wrapped) +# cf = CloneFunction() +# code = await cf.run(template_func=template_func, source_code=func_code) +# # Display the generated function in the terminal. +# logger_code = highlight(code, "python") +# logger.info(f"Creating following Python function:\n{logger_code}") +# # execute this function. +# try: +# res = run_function_code(code, func_name, *args, **kwargs) +# if self.save_code and self.code_file_path: +# cf._save(self.code_file_path, code) +# except Exception as e: +# logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}") +# raise Exception("Could not evaluate Python code", e) +# return res +# +# return wrapper(wrapped) diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py index e4b23d538..cda164ec5 100644 --- a/metagpt/tools/moderation.py +++ b/metagpt/tools/moderation.py @@ -22,13 +22,6 @@ class Moderation: resp.append({"flagged": item.flagged, "true_categories": true_categories}) return resp - def moderation_with_categories(self, content: Union[str, list[str]]): - resp = [] - if content: - moderation_results = self.llm.moderation(content=content) - resp = self.handle_moderation_results(moderation_results.results) - return resp - async def amoderation_with_categories(self, content: Union[str, list[str]]): resp = [] if content: @@ -36,16 +29,6 @@ class Moderation: resp = self.handle_moderation_results(moderation_results.results) return resp - def moderation(self, content: Union[str, list[str]]): - resp = [] - if content: - moderation_results = self.llm.moderation(content=content) - results = moderation_results.results - for item in results: - resp.append(item.flagged) - - return resp - async def amoderation(self, content: Union[str, list[str]]): resp = [] if content: diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index bd6078245..aa00abdcc 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -29,7 +29,7 @@ class OpenAIText2Image: :return: The image data is returned in Base64 encoding. """ try: - result = await self._llm.async_client.images.generate(prompt=text, n=1, size=size_type) + result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type) except Exception as e: logger.error(f"An error occurred:{e}") return "" diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index 7f16aa9a4..ddadda7e6 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -1,58 +1,58 @@ -#!/usr/bin/env python3 -# _*_ coding: utf-8 _*_ - -""" -@Time : 2023/10/09 18:40:34 -@Author : Stitch-z -@File : test_invoice_ocr.py -""" - -import os -from pathlib import Path - -import pytest - -from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "invoice_path", - [ - "../../data/invoices/invoice-3.jpg", - "../../data/invoices/invoice-4.zip", - ], -) -async def test_invoice_ocr(invoice_path: str): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) - assert isinstance(resp, list) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("invoice_path", "expected_result"), - [ - ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), - ], -) -async def test_generate_table(invoice_path: str, expected_result: list[dict]): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) - table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) - assert table_data == expected_result - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("invoice_path", "query", "expected_result"), - [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], -) -async def test_reply_question(invoice_path: str, query: dict, expected_result: str): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) - result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) - assert expected_result in result +# #!/usr/bin/env python3 +# # _*_ coding: utf-8 _*_ +# +# """ +# @Time : 2023/10/09 18:40:34 +# @Author : Stitch-z +# @File : test_invoice_ocr.py +# """ +# +# import os +# from pathlib import Path +# +# import pytest +# +# from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "invoice_path", +# [ +# "../../data/invoices/invoice-3.jpg", +# "../../data/invoices/invoice-4.zip", +# ], +# ) +# async def test_invoice_ocr(invoice_path: str): +# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) +# filename = os.path.basename(invoice_path) +# resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +# assert isinstance(resp, list) +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("invoice_path", "expected_result"), +# [ +# ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), +# ], +# ) +# async def test_generate_table(invoice_path: str, expected_result: list[dict]): +# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) +# filename = os.path.basename(invoice_path) +# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +# table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) +# assert table_data == expected_result +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("invoice_path", "query", "expected_result"), +# [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], +# ) +# async def test_reply_question(invoice_path: str, query: dict, expected_result: str): +# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) +# filename = os.path.basename(invoice_path) +# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +# result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) +# assert expected_result in result diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py index 5c0e40f57..1b7368620 100644 --- a/tests/metagpt/document_store/test_lancedb_store.py +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -7,12 +7,9 @@ """ import random -import pytest - from metagpt.document_store.lancedb_store import LanceStore -@pytest def test_lance_store(): # This simply establishes the connection to the database, so we can drop the table if it exists store = LanceStore("test") diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index ab3092004..e90182dde 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -1,63 +1,63 @@ -#!/usr/bin/env python3 -# _*_ coding: utf-8 _*_ - -""" -@Time : 2023/9/21 23:11:27 -@Author : Stitch-z -@File : test_invoice_ocr_assistant.py -""" - -import json -from pathlib import Path - -import pandas as pd -import pytest - -from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath -from metagpt.schema import Message - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("query", "invoice_path", "invoice_table_path", "expected_result"), - [ - ( - "Invoicing date", - Path("../../data/invoices/invoice-1.pdf"), - Path("../../../data/invoice_table/invoice-1.xlsx"), - [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-2.png"), - Path("../../../data/invoice_table/invoice-2.xlsx"), - [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-3.jpg"), - Path("../../../data/invoice_table/invoice-3.xlsx"), - [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-4.zip"), - Path("../../../data/invoice_table/invoice-4.xlsx"), - [ - {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, - {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, - {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, - ], - ), - ], -) -async def test_invoice_ocr_assistant( - query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] -): - invoice_path = Path.cwd() / invoice_path - role = InvoiceOCRAssistant() - await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) - invoice_table_path = Path.cwd() / invoice_table_path - df = pd.read_excel(invoice_table_path) - dict_result = df.to_dict(orient="records") - assert json.dumps(dict_result) == json.dumps(expected_result) +# #!/usr/bin/env python3 +# # _*_ coding: utf-8 _*_ +# +# """ +# @Time : 2023/9/21 23:11:27 +# @Author : Stitch-z +# @File : test_invoice_ocr_assistant.py +# """ +# +# import json +# from pathlib import Path +# +# import pandas as pd +# import pytest +# +# from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath +# from metagpt.schema import Message +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("query", "invoice_path", "invoice_table_path", "expected_result"), +# [ +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-1.pdf"), +# Path("../../../data/invoice_table/invoice-1.xlsx"), +# [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], +# ), +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-2.png"), +# Path("../../../data/invoice_table/invoice-2.xlsx"), +# [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], +# ), +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-3.jpg"), +# Path("../../../data/invoice_table/invoice-3.xlsx"), +# [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], +# ), +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-4.zip"), +# Path("../../../data/invoice_table/invoice-4.xlsx"), +# [ +# {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, +# {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, +# {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, +# ], +# ), +# ], +# ) +# async def test_invoice_ocr_assistant( +# query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] +# ): +# invoice_path = Path.cwd() / invoice_path +# role = InvoiceOCRAssistant() +# await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) +# invoice_table_path = Path.cwd() / invoice_table_path +# df = pd.read_excel(invoice_table_path) +# dict_result = df.to_dict(orient="records") +# assert json.dumps(dict_result) == json.dumps(expected_result) diff --git a/tests/metagpt/test_message.py b/tests/metagpt/test_message.py index 8f267ba54..cf6f744dc 100644 --- a/tests/metagpt/test_message.py +++ b/tests/metagpt/test_message.py @@ -8,7 +8,7 @@ """ import pytest -from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMessage +from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage def test_message(): @@ -29,13 +29,5 @@ def test_all_messages(): assert msg.content == test_content -def test_raw_message(): - msg = RawMessage(role="user", content="raw") - assert msg["role"] == "user" - assert msg["content"] == "raw" - with pytest.raises(KeyError): - assert msg["1"] == 1, "KeyError: '1'" - - if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_code_interpreter.py b/tests/metagpt/tools/test_code_interpreter.py index b8380967c..71df6315b 100644 --- a/tests/metagpt/tools/test_code_interpreter.py +++ b/tests/metagpt/tools/test_code_interpreter.py @@ -8,53 +8,46 @@ open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible. """ -from pathlib import Path - -import pandas as pd -import pytest - -from metagpt.actions import Action -from metagpt.logs import logger -from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator - -logger.add("./tests/data/test_ci.log") -stock = "./tests/data/baba_stock.csv" - - -# TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。 -class CreateStockIndicators(Action): - @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py") - async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame: - """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包; - 指标生成对应的三列: SMA, BB_upper, BB_lower - """ - ... - - -@pytest.mark.asyncio -async def test_actions(): - # Prerequisites - # Conflict with openai 1.x - - # 计算指标 - indicators = ["Simple Moving Average", "BollingerBands"] - stocker = CreateStockIndicators() - df, msg = await stocker.run(stock, indicators=indicators) - assert isinstance(df, pd.DataFrame) - assert "Close" in df.columns - assert "Date" in df.columns - # 将df保存为文件,将文件路径传入到下一个action - df_path = "./tests/data/stock_indicators.csv" - df.to_csv(df_path) - assert Path(df_path).is_file() - # 可视化指标结果 - figure_path = "./tests/data/figure_ci.png" - ci_ploter = OpenCodeInterpreter() - ci_ploter.chat( - f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" - ) - assert Path(figure_path).is_file() - - -if __name__ == "__main__": - pytest.main([__file__, "-s"]) +# from pathlib import Path +# +# import pandas as pd +# import pytest +# +# from metagpt.actions import Action +# from metagpt.logs import logger +# from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator +# +# logger.add("./tests/data/test_ci.log") +# stock = "./tests/data/baba_stock.csv" +# +# +# # TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。 +# class CreateStockIndicators(Action): +# @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py") +# async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame: +# """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包; +# 指标生成对应的三列: SMA, BB_upper, BB_lower +# """ +# ... +# +# +# @pytest.mark.asyncio +# async def test_actions(): +# # 计算指标 +# indicators = ["Simple Moving Average", "BollingerBands"] +# stocker = CreateStockIndicators() +# df, msg = await stocker.run(stock, indicators=indicators) +# assert isinstance(df, pd.DataFrame) +# assert "Close" in df.columns +# assert "Date" in df.columns +# # 将df保存为文件,将文件路径传入到下一个action +# df_path = "./tests/data/stock_indicators.csv" +# df.to_csv(df_path) +# assert Path(df_path).is_file() +# # 可视化指标结果 +# figure_path = "./tests/data/figure_ci.png" +# ci_ploter = OpenCodeInterpreter() +# ci_ploter.chat( +# f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" +# ) +# assert Path(figure_path).is_file() diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index c71611bd3..534fe812a 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -12,33 +12,6 @@ from metagpt.config import CONFIG from metagpt.tools.moderation import Moderation -@pytest.mark.parametrize( - ("content",), - [ - [ - ["I will kill you", "The weather is really nice today", "I want to hit you"], - ] - ], -) -def test_moderation(content): - # Prerequisites - assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" - assert not CONFIG.OPENAI_API_TYPE - assert CONFIG.OPENAI_API_MODEL - - moderation = Moderation() - results = moderation.moderation(content=content) - assert isinstance(results, list) - assert len(results) == len(content) - - results = moderation.moderation_with_categories(content=content) - assert isinstance(results, list) - assert results - for m in results: - assert "flagged" in m - assert "true_categories" in m - - @pytest.mark.asyncio @pytest.mark.parametrize( ("content",), diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py index 24691a5e9..e560da798 100644 --- a/tests/metagpt/tools/test_openai_text_to_image.py +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -9,7 +9,10 @@ import pytest from metagpt.config import CONFIG -from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image +from metagpt.tools.openai_text_to_image import ( + OpenAIText2Image, + oas3_openai_text_to_image, +) @pytest.mark.asyncio @@ -23,5 +26,13 @@ async def test_draw(): assert binary_data +@pytest.mark.asyncio +async def test_get_image(): + data = await OpenAIText2Image.get_image_data( + url="https://www.baidu.com/img/PCtm_d9c8750bed0b3c7d089fa7d55720d6cf.png" + ) + assert data + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 0ab34437d..5fb5f8a47 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -8,6 +8,7 @@ """ import os +import platform from typing import Any, Set import pytest @@ -17,7 +18,7 @@ from metagpt.actions import RunCode from metagpt.const import get_metagpt_root from metagpt.roles.tutorial_assistant import TutorialAssistant from metagpt.schema import Message -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, check_cmd_exists class TestGetProjectRoot: @@ -28,13 +29,12 @@ class TestGetProjectRoot: def test_get_project_root(self): project_root = get_metagpt_root() - assert project_root.name == "metagpt" + assert project_root.name == "MetaGPT" def test_get_root_exception(self): - with pytest.raises(Exception) as exc_info: - self.change_etc_dir() - get_metagpt_root() - assert str(exc_info.value) == "Project root not found." + self.change_etc_dir() + project_root = get_metagpt_root() + assert project_root def test_any_to_str(self): class Input(BaseModel): @@ -65,8 +65,8 @@ class TestGetProjectRoot: want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"}, ), Input( - x={TutorialAssistant, RunCode(), "a"}, - want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"}, + x={TutorialAssistant, "a"}, + want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "a"}, ), Input( x=(TutorialAssistant, RunCode(), "a"), @@ -77,6 +77,25 @@ class TestGetProjectRoot: v = any_to_str_set(i.x) assert v == i.want + def test_check_cmd_exists(self): + class Input(BaseModel): + command: str + platform: str + + inputs = [ + {"command": "cat", "platform": "linux"}, + {"command": "ls", "platform": "linux"}, + {"command": "mspaint", "platform": "windows"}, + ] + plat = "windows" if platform.system().lower() == "windows" else "linux" + for i in inputs: + seed = Input(**i) + result = check_cmd_exists(seed.command) + if plat == seed.platform: + assert result == 0 + else: + assert result != 0 + if __name__ == "__main__": pytest.main([__file__, "-s"])