From 6a1690095364aa1c34528b94092f8ff445f82600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 26 Dec 2023 10:03:56 +0800 Subject: [PATCH 1/4] fixbug: config.yaml feat: +tests --- config/config.yaml | 1 - metagpt/tools/azure_tts.py | 11 +------ metagpt/tools/hello.py | 4 ++- requirements-test.txt | 5 ++++ requirements.txt | 8 +++--- tests/metagpt/tools/test_azure_tts.py | 30 ++++++++++++-------- tests/metagpt/tools/test_code_interpreter.py | 17 +++++++++++ tests/metagpt/tools/test_hello.py | 30 ++++++++++++++++++++ 8 files changed, 78 insertions(+), 28 deletions(-) create mode 100644 requirements-test.txt create mode 100644 tests/metagpt/tools/test_hello.py diff --git a/config/config.yaml b/config/config.yaml index 711110f97..5025a4977 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -121,7 +121,6 @@ TIMEOUT: 60 # Timeout for llm invocation # PROMPT_FORMAT: json #json or markdown -<<<<<<< HEAD ### Agent configurations # RAISE_NOT_CONFIG_ERROR: true # "true" if the LLM key is not configured, throw a NotConfiguredException, else "false". # WORKSPACE_PATH_WITH_UID: false # "true" if using `{workspace}/{uid}` as the workspace path; "false" use `{workspace}`. diff --git a/metagpt/tools/azure_tts.py b/metagpt/tools/azure_tts.py index 8fdb10c13..d3e67c269 100644 --- a/metagpt/tools/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -6,7 +6,6 @@ @File : azure_tts.py @Modified by: mashenquan, 2023/8/17. Azure TTS OAS3 api, which provides text-to-speech functionality """ -import asyncio import base64 from pathlib import Path from uuid import uuid4 @@ -14,7 +13,7 @@ from uuid import uuid4 import aiofiles from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer -from metagpt.config import CONFIG, Config +from metagpt.config import CONFIG from metagpt.logs import logger @@ -103,11 +102,3 @@ async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscripti return "" return base64_string - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - v = loop.create_task(oas3_azsure_tts("测试,test")) - loop.run_until_complete(v) - print(v) diff --git a/metagpt/tools/hello.py b/metagpt/tools/hello.py index 8a21e1b4e..52d2d11c1 100644 --- a/metagpt/tools/hello.py +++ b/metagpt/tools/hello.py @@ -12,6 +12,7 @@ -H 'Content-Type: application/json' \ -d '{}' """ +from pathlib import Path import connexion @@ -22,6 +23,7 @@ async def post_greeting(name: str) -> str: if __name__ == "__main__": - app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/") + specification_dir = Path(__file__).parent.parent.parent / ".well-known" + app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("openapi.yaml", arguments={"title": "Hello World Example"}) app.run(port=8080) diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 000000000..39ba608b7 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,5 @@ +# For unit test +-r requirements.txt + +connexion[uvicorn]~=3.0.5 +azure-cognitiveservices-speech~=1.31.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5cb01ab99..f2566fb15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,12 +40,12 @@ typing_extensions==4.7.0 libcst==1.0.1 qdrant-client==1.4.0 pytest-mock==3.11.1 -# open-interpreter==0.1.7; python_version>"3.9" +# open-interpreter==0.1.7; python_version>"3.9" # Conflict with openai 1.x ta==0.10.2 semantic-kernel==0.4.0.dev0 wrapt==1.15.0 #aiohttp_jinja2 -#azure-cognitiveservices-speech~=1.31.0 +# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py #aioboto3~=11.3.0 #redis==4.3.5 websocket-client==1.6.2 @@ -54,8 +54,8 @@ gitpython==3.1.40 zhipuai==1.0.7 socksio~=1.0.0 gitignore-parser==0.1.9 -# connexion[swagger-ui] +# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/hello.py websockets~=12.0 networkx~=3.2.1 google-generativeai==0.3.1 -playwright==1.40.0 \ No newline at end of file +playwright==1.40.0 diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py index b7f94a19c..38fef557e 100644 --- a/tests/metagpt/tools/test_azure_tts.py +++ b/tests/metagpt/tools/test_azure_tts.py @@ -7,13 +7,20 @@ @Modified By: mashenquan, 2023-8-9, add more text formatting options @Modified By: mashenquan, 2023-8-17, move to `tools` folder. """ -import asyncio + +import pytest +from azure.cognitiveservices.speech import ResultReason from metagpt.config import CONFIG from metagpt.tools.azure_tts import AzureTTS -def test_azure_tts(): +@pytest.mark.asyncio +async def test_azure_tts(): + # Prerequisites + assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY" + assert CONFIG.AZURE_TTS_REGION + azure_tts = AzureTTS(subscription_key="", region="") text = """ 女儿看见父亲走了进来,问道: @@ -25,20 +32,19 @@ def test_azure_tts(): “Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.” """ - path = CONFIG.workspace / "tts" + path = CONFIG.workspace_path / "tts" path.mkdir(exist_ok=True, parents=True) filename = path / "girl.wav" - loop = asyncio.new_event_loop() - v = loop.create_task( - azure_tts.synthesize_speech(lang="zh-CN", voice="zh-CN-XiaomoNeural", text=text, output_file=str(filename)) + filename.unlink(missing_ok=True) + result = await azure_tts.synthesize_speech( + lang="zh-CN", voice="zh-CN-XiaomoNeural", text=text, output_file=str(filename) ) - result = loop.run_until_complete(v) - print(result) - - # 运行需要先配置 SUBSCRIPTION_KEY - # TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有 + assert result + assert result.audio_data + assert result.reason == ResultReason.SynthesizingAudioCompleted + assert filename.exists() if __name__ == "__main__": - test_azure_tts() + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_code_interpreter.py b/tests/metagpt/tools/test_code_interpreter.py index 03d4ce8df..b8380967c 100644 --- a/tests/metagpt/tools/test_code_interpreter.py +++ b/tests/metagpt/tools/test_code_interpreter.py @@ -1,3 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : +@Author : +@File : test_code_interpreter.py +@Warning : open-interpreter 0.1.17 requires openai<0.29.0,>=0.28.0, but you have openai 1.6.0 which is incompatible. + open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible. +""" + from pathlib import Path import pandas as pd @@ -23,6 +33,9 @@ class CreateStockIndicators(Action): @pytest.mark.asyncio async def test_actions(): + # Prerequisites + # Conflict with openai 1.x + # 计算指标 indicators = ["Simple Moving Average", "BollingerBands"] stocker = CreateStockIndicators() @@ -41,3 +54,7 @@ async def test_actions(): f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" ) assert Path(figure_path).is_file() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_hello.py b/tests/metagpt/tools/test_hello.py new file mode 100644 index 000000000..037dcd1b7 --- /dev/null +++ b/tests/metagpt/tools/test_hello.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_hello.py +""" +import subprocess +from pathlib import Path + +import pytest +import requests + + +@pytest.mark.asyncio +async def test_hello(): + script_pathname = Path(__file__).resolve() + process = subprocess.Popen(["python", str(script_pathname)]) + + url = "http://localhost:8080/openapi/greeting/dave" + headers = {"accept": "text/plain", "Content-Type": "application/json"} + data = {} + response = requests.post(url, headers=headers, json=data) + assert response.text == "Hello dave\n" + + process.terminate() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) From 6512f40ddd3693ee12e4230115df363255814892 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 26 Dec 2023 13:31:50 +0800 Subject: [PATCH 2/4] feat: +unit test --- metagpt/learn/text_to_image.py | 6 ++- metagpt/learn/text_to_speech.py | 6 +-- metagpt/tools/azure_tts.py | 3 +- metagpt/tools/hello.py | 4 +- metagpt/tools/iflytek_tts.py | 24 +++------- metagpt/tools/metagpt_oas3_api_svc.py | 28 +++--------- metagpt/tools/metagpt_text_to_image.py | 22 +++------ metagpt/tools/moderation.py | 45 +++++++++++++------ metagpt/tools/openai_text_to_embedding.py | 39 ++++++++-------- metagpt/tools/openai_text_to_image.py | 19 ++------ metagpt/tools/web_browser_engine_selenium.py | 3 +- requirements-test.txt | 9 +++- requirements.txt | 2 +- tests/metagpt/tools/test_hello.py | 6 ++- tests/metagpt/tools/test_iflytek_tts.py | 31 +++++++++++++ .../tools/test_metagpt_oas3_api_svc.py | 32 +++++++++++++ .../tools/test_metagpt_text_to_image.py | 25 +++++++++++ tests/metagpt/tools/test_moderation.py | 29 ++++++++++++ .../tools/test_openai_text_to_embedding.py | 30 +++++++++++++ .../tools/test_openai_text_to_image.py | 27 +++++++++++ ...mpt_generator.py => test_prompt_writer.py} | 2 +- tests/metagpt/tools/test_search_engine.py | 14 ++++++ .../tools/test_search_engine_meilisearch.py | 12 +++++ ...test_ut_generator.py => test_ut_writer.py} | 0 .../metagpt/tools/test_web_browser_engine.py | 15 ++++--- .../test_web_browser_engine_playwright.py | 25 ++++++----- .../tools/test_web_browser_engine_selenium.py | 26 ++++++----- 27 files changed, 333 insertions(+), 151 deletions(-) create mode 100644 tests/metagpt/tools/test_iflytek_tts.py create mode 100644 tests/metagpt/tools/test_metagpt_oas3_api_svc.py create mode 100644 tests/metagpt/tools/test_metagpt_text_to_image.py create mode 100644 tests/metagpt/tools/test_openai_text_to_embedding.py create mode 100644 tests/metagpt/tools/test_openai_text_to_image.py rename tests/metagpt/tools/{test_prompt_generator.py => test_prompt_writer.py} (97%) rename tests/metagpt/tools/{test_ut_generator.py => test_ut_writer.py} (100%) diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index eaf528b3e..c3c62fb67 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -6,6 +6,7 @@ @File : text_to_image.py @Desc : Text-to-Image skill, which provides text-to-image functionality. """ +import base64 from metagpt.config import CONFIG from metagpt.const import BASE64_FORMAT @@ -25,11 +26,12 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod """ image_declaration = "data:image/png;base64," if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url: - base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url) + binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url) elif CONFIG.OPENAI_API_KEY or openai_api_key: - base64_data = await oas3_openai_text_to_image(text, size_type) + binary_data = await oas3_openai_text_to_image(text, size_type) else: raise ValueError("Missing necessary parameters.") + base64_data = base64.b64encode(binary_data).decode("utf-8") s3 = S3() url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else "" diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 72958b8c7..ecd00c724 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -6,7 +6,6 @@ @File : text_to_speech.py @Desc : Text-to-Speech skill, which provides text-to-speech functionality """ -import openai from metagpt.config import CONFIG from metagpt.const import BASE64_FORMAT @@ -66,7 +65,6 @@ async def text_to_speech( return f"[{text}]({url})" return audio_declaration + base64_data if base64_data else base64_data - raise openai.InvalidRequestError( - message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error", - param={}, + raise ValueError( + "AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error" ) diff --git a/metagpt/tools/azure_tts.py b/metagpt/tools/azure_tts.py index d3e67c269..f4f8aa0a2 100644 --- a/metagpt/tools/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -96,9 +96,10 @@ async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscripti async with aiofiles.open(filename, mode="rb") as reader: data = await reader.read() base64_string = base64.b64encode(data).decode("utf-8") - filename.unlink() except Exception as e: logger.error(f"text:{text}, error:{e}") return "" + finally: + filename.unlink(missing_ok=True) return base64_string diff --git a/metagpt/tools/hello.py b/metagpt/tools/hello.py index 52d2d11c1..ec7fc9231 100644 --- a/metagpt/tools/hello.py +++ b/metagpt/tools/hello.py @@ -7,7 +7,7 @@ @Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service: curl -X 'POST' \ - 'http://localhost:8080/openapi/greeting/dave' \ + 'http://localhost:8082/openapi/greeting/dave' \ -H 'accept: text/plain' \ -H 'Content-Type: application/json' \ -d '{}' @@ -26,4 +26,4 @@ if __name__ == "__main__": specification_dir = Path(__file__).parent.parent.parent / ".well-known" app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("openapi.yaml", arguments={"title": "Hello World Example"}) - app.run(port=8080) + app.run(port=8082) diff --git a/metagpt/tools/iflytek_tts.py b/metagpt/tools/iflytek_tts.py index cb87d2e7f..ad2395362 100644 --- a/metagpt/tools/iflytek_tts.py +++ b/metagpt/tools/iflytek_tts.py @@ -6,7 +6,6 @@ @File : iflytek_tts.py @Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality """ -import asyncio import base64 import hashlib import hmac @@ -74,12 +73,13 @@ class IFlyTekTTS(object): await websocket.send(req) # receive frames - async with aiofiles.open(str(output_file), "w") as writer: + async with aiofiles.open(str(output_file), "wb") as writer: while True: v = await websocket.recv() rsp = IFlyTekTTSResponse(**json.loads(v)) if rsp.data: - await writer.write(rsp.data.audio) + binary_data = base64.b64decode(rsp.data.audio) + await writer.write(binary_data) if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value: continue break @@ -140,23 +140,13 @@ async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key try: tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret) await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice) - async with aiofiles.open(str(filename), mode="r") as reader: - base64_string = await reader.read() + async with aiofiles.open(str(filename), mode="rb") as reader: + data = await reader.read() + base64_string = base64.b64encode(data).decode("utf-8") except Exception as e: logger.error(f"text:{text}, error:{e}") base64_string = "" finally: - filename.unlink() + filename.unlink(missing_ok=True) return base64_string - - -if __name__ == "__main__": - asyncio.get_event_loop().run_until_complete( - oas3_iflytek_tts( - text="你好,hello", - app_id="f7acef62", - api_key="fda72e3aa286042a492525816a5efa08", - api_secret="ZDk3NjdiMDBkODJlOWQ1NjRjMGI2NDY4", - ) - ) diff --git a/metagpt/tools/metagpt_oas3_api_svc.py b/metagpt/tools/metagpt_oas3_api_svc.py index 2ff4c8225..319e7efb2 100644 --- a/metagpt/tools/metagpt_oas3_api_svc.py +++ b/metagpt/tools/metagpt_oas3_api_svc.py @@ -6,39 +6,21 @@ @File : metagpt_oas3_api_svc.py @Desc : MetaGPT OpenAPI Specification 3.0 REST API service """ -import asyncio -import sys + from pathlib import Path import connexion -sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt' - def oas_http_svc(): """Start the OAS 3.0 OpenAPI HTTP service""" - app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/") + print("http://localhost:8080/oas3/ui/") + specification_dir = Path(__file__).parent.parent.parent / ".well-known" + app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("metagpt_oas3_api.yaml") app.add_api("openapi.yaml") app.run(port=8080) -async def async_main(): - """Start the OAS 3.0 OpenAPI HTTP service in the background.""" - loop = asyncio.get_event_loop() - loop.run_in_executor(None, oas_http_svc) - - # TODO: replace following codes: - while True: - await asyncio.sleep(1) - print("sleep") - - -def main(): - print("http://localhost:8080/oas3/ui/") - oas_http_svc() - - if __name__ == "__main__": - # asyncio.run(async_main()) - main() + oas_http_svc() diff --git a/metagpt/tools/metagpt_text_to_image.py b/metagpt/tools/metagpt_text_to_image.py index 50c0edcba..9a84e69eb 100644 --- a/metagpt/tools/metagpt_text_to_image.py +++ b/metagpt/tools/metagpt_text_to_image.py @@ -6,7 +6,6 @@ @File : metagpt_text_to_image.py @Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality. """ -import asyncio import base64 from typing import Dict, List @@ -14,7 +13,7 @@ import aiohttp import requests from pydantic import BaseModel -from metagpt.config import CONFIG, Config +from metagpt.config import CONFIG from metagpt.logs import logger @@ -75,11 +74,12 @@ class MetaGPTText2Image: async with session.post(self.model_url, headers=headers, json=data) as response: result = ImageResult(**await response.json()) if len(result.images) == 0: - return "" - return result.images[0] + return 0 + data = base64.b64decode(result.images[0]) + return data except requests.exceptions.RequestException as e: logger.error(f"An error occurred:{e}") - return "" + return 0 # Export @@ -96,15 +96,3 @@ async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url if not model_url: model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type) - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - task = loop.create_task(oas3_metagpt_text_to_image("Panda emoji")) - v = loop.run_until_complete(task) - print(v) - data = base64.b64decode(v) - with open("tmp.png", mode="wb") as writer: - writer.write(data) - print(v) diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py index 5532e4f66..e4b23d538 100644 --- a/metagpt/tools/moderation.py +++ b/metagpt/tools/moderation.py @@ -5,7 +5,6 @@ @Author : zhanglei @File : moderation.py """ -import asyncio from typing import Union from metagpt.llm import LLM @@ -15,6 +14,38 @@ class Moderation: def __init__(self): self.llm = LLM() + def handle_moderation_results(self, results): + resp = [] + for item in results: + categories = item.categories.dict() + true_categories = [category for category, item_flagged in categories.items() if item_flagged] + resp.append({"flagged": item.flagged, "true_categories": true_categories}) + return resp + + def moderation_with_categories(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = self.llm.moderation(content=content) + resp = self.handle_moderation_results(moderation_results.results) + return resp + + async def amoderation_with_categories(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = await self.llm.amoderation(content=content) + resp = self.handle_moderation_results(moderation_results.results) + return resp + + def moderation(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = self.llm.moderation(content=content) + results = moderation_results.results + for item in results: + resp.append(item.flagged) + + return resp + async def amoderation(self, content: Union[str, list[str]]): resp = [] if content: @@ -24,15 +55,3 @@ class Moderation: resp.append(item.flagged) return resp - - -async def main(): - moderation = Moderation() - rsp = await moderation.amoderation( - content=["I will kill you", "The weather is really nice today", "I want to hit you"] - ) - print(rsp) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/metagpt/tools/openai_text_to_embedding.py b/metagpt/tools/openai_text_to_embedding.py index fb6fbc653..52b2cc9eb 100644 --- a/metagpt/tools/openai_text_to_embedding.py +++ b/metagpt/tools/openai_text_to_embedding.py @@ -7,14 +7,13 @@ @Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality. For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object` """ -import asyncio from typing import List import aiohttp import requests -from pydantic import BaseModel +from pydantic import BaseModel, Field -from metagpt.config import CONFIG, Config +from metagpt.config import CONFIG from metagpt.logs import logger @@ -29,15 +28,18 @@ class Embedding(BaseModel): class Usage(BaseModel): - prompt_tokens: int - total_tokens: int + prompt_tokens: int = 0 + total_tokens: int = 0 class ResultEmbedding(BaseModel): - object: str - data: List[Embedding] - model: str - usage: Usage + class Config: + alias = {"object_": "object"} + + object_: str = "" + data: List[Embedding] = [] + model: str = "" + usage: Usage = Field(default_factory=Usage) class OpenAIText2Embedding: @@ -45,7 +47,7 @@ class OpenAIText2Embedding: """ :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ - self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY + self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY async def text_2_embedding(self, text, model="text-embedding-ada-002"): """Text to embedding @@ -55,15 +57,18 @@ class OpenAIText2Embedding: :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. """ + proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {} headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"} data = {"input": text, "model": model} + url = "https://api.openai.com/v1/embeddings" try: async with aiohttp.ClientSession() as session: - async with session.post("https://api.openai.com/v1/embeddings", headers=headers, json=data) as response: - return await response.json() + async with session.post(url, headers=headers, json=data, **proxies) as response: + data = await response.json() + return ResultEmbedding(**data) except requests.exceptions.RequestException as e: logger.error(f"An error occurred:{e}") - return {} + return ResultEmbedding() # Export @@ -80,11 +85,3 @@ async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", op if not openai_api_key: openai_api_key = CONFIG.OPENAI_API_KEY return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model) - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - task = loop.create_task(oas3_openai_text_to_embedding("Panda emoji")) - v = loop.run_until_complete(task) - print(v) diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index 71381d8f2..fcfa86c7d 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -6,13 +6,10 @@ @File : openai_text_to_image.py @Desc : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality. """ -import asyncio -import base64 import aiohttp import requests -from metagpt.config import Config from metagpt.llm import LLM from metagpt.logs import logger @@ -23,7 +20,6 @@ class OpenAIText2Image: :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ self._llm = LLM() - self._client = self._llm.async_client def __del__(self): if self._llm: @@ -37,7 +33,7 @@ class OpenAIText2Image: :return: The image data is returned in Base64 encoding. """ try: - result = await self._client.images.generate(prompt=text, n=1, size=size_type) + result = await self._llm.async_client.images.generate(prompt=text, n=1, size=size_type) except Exception as e: logger.error(f"An error occurred:{e}") return "" @@ -57,12 +53,11 @@ class OpenAIText2Image: async with session.get(url) as response: response.raise_for_status() # 如果是 4xx 或 5xx 响应,会引发异常 image_data = await response.read() - base64_image = base64.b64encode(image_data).decode("utf-8") - return base64_image + return image_data except requests.exceptions.RequestException as e: logger.error(f"An error occurred:{e}") - return "" + return 0 # Export @@ -76,11 +71,3 @@ async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"): if not text: return "" return await OpenAIText2Image().text_2_image(text, size_type=size_type) - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - task = loop.create_task(oas3_openai_text_to_image("Panda emoji")) - v = loop.run_until_complete(task) - print(v) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 628c8dea2..cabae7531 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -9,7 +9,7 @@ import asyncio import importlib from concurrent import futures from copy import deepcopy -from typing import Dict, Literal +from typing import Literal from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC @@ -33,7 +33,6 @@ class SeleniumWrapper: def __init__( self, - options: Dict, browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None, launch_kwargs: dict | None = None, *, diff --git a/requirements-test.txt b/requirements-test.txt index 39ba608b7..fcf265163 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,4 +2,11 @@ -r requirements.txt connexion[uvicorn]~=3.0.5 -azure-cognitiveservices-speech~=1.31.0 \ No newline at end of file +azure-cognitiveservices-speech~=1.31.0 +duckduckgo_search +serpapi +google +httplib2 +google_api_python_client +selenium +webdriver_manager \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f2566fb15..c8d21dfc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ faiss_cpu==1.7.4 fire==0.4.0 typer # godot==0.1.1 -# google_api_python_client==2.93.0 +# google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.1.16 langchain==0.0.352 loguru==0.6.0 diff --git a/tests/metagpt/tools/test_hello.py b/tests/metagpt/tools/test_hello.py index 037dcd1b7..fdf67ac35 100644 --- a/tests/metagpt/tools/test_hello.py +++ b/tests/metagpt/tools/test_hello.py @@ -5,6 +5,7 @@ @Author : mashenquan @File : test_hello.py """ +import asyncio import subprocess from pathlib import Path @@ -14,10 +15,11 @@ import requests @pytest.mark.asyncio async def test_hello(): - script_pathname = Path(__file__).resolve() + script_pathname = Path(__file__).parent / "../../../metagpt/tools/hello.py" process = subprocess.Popen(["python", str(script_pathname)]) + await asyncio.sleep(5) - url = "http://localhost:8080/openapi/greeting/dave" + url = "http://localhost:8082/openapi/greeting/dave" headers = {"accept": "text/plain", "Content-Type": "application/json"} data = {} response = requests.post(url, headers=headers, json=data) diff --git a/tests/metagpt/tools/test_iflytek_tts.py b/tests/metagpt/tools/test_iflytek_tts.py new file mode 100644 index 000000000..58d8a83ce --- /dev/null +++ b/tests/metagpt/tools/test_iflytek_tts.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_iflytek_tts.py +""" +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.iflytek_tts import oas3_iflytek_tts + + +@pytest.mark.asyncio +async def test_tts(): + # Prerequisites + assert CONFIG.IFLYTEK_APP_ID + assert CONFIG.IFLYTEK_API_KEY + assert CONFIG.IFLYTEK_API_SECRET + + result = await oas3_iflytek_tts( + text="你好,hello", + app_id=CONFIG.IFLYTEK_APP_ID, + api_key=CONFIG.IFLYTEK_API_KEY, + api_secret=CONFIG.IFLYTEK_API_SECRET, + ) + assert result + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_metagpt_oas3_api_svc.py b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py new file mode 100644 index 000000000..e0f17aa05 --- /dev/null +++ b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_metagpt_oas3_api_svc.py +""" +import asyncio +import subprocess +from pathlib import Path + +import pytest +import requests + + +@pytest.mark.asyncio +async def test_oas2_svc(): + script_pathname = Path(__file__).parent / "../../../metagpt/tools/metagpt_oas3_api_svc.py" + process = subprocess.Popen(["python", str(script_pathname)]) + await asyncio.sleep(5) + + url = "http://localhost:8080/openapi/greeting/dave" + headers = {"accept": "text/plain", "Content-Type": "application/json"} + data = {} + response = requests.post(url, headers=headers, json=data) + assert response.text == "Hello dave\n" + + process.terminate() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_metagpt_text_to_image.py b/tests/metagpt/tools/test_metagpt_text_to_image.py new file mode 100644 index 000000000..f5ced2061 --- /dev/null +++ b/tests/metagpt/tools/test_metagpt_text_to_image.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_metagpt_text_to_image.py +""" + +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image + + +@pytest.mark.asyncio +async def test_draw(): + # Prerequisites + assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL + + binary_data = await oas3_metagpt_text_to_image("Panda emoji") + assert binary_data + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index 5ec3bd4de..c71611bd3 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -8,6 +8,7 @@ import pytest +from metagpt.config import CONFIG from metagpt.tools.moderation import Moderation @@ -20,11 +21,23 @@ from metagpt.tools.moderation import Moderation ], ) def test_moderation(content): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + moderation = Moderation() results = moderation.moderation(content=content) assert isinstance(results, list) assert len(results) == len(content) + results = moderation.moderation_with_categories(content=content) + assert isinstance(results, list) + assert results + for m in results: + assert "flagged" in m + assert "true_categories" in m + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -36,7 +49,23 @@ def test_moderation(content): ], ) async def test_amoderation(content): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + moderation = Moderation() results = await moderation.amoderation(content=content) assert isinstance(results, list) assert len(results) == len(content) + + results = await moderation.amoderation_with_categories(content=content) + assert isinstance(results, list) + assert results + for m in results: + assert "flagged" in m + assert "true_categories" in m + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_openai_text_to_embedding.py b/tests/metagpt/tools/test_openai_text_to_embedding.py new file mode 100644 index 000000000..086c9d45b --- /dev/null +++ b/tests/metagpt/tools/test_openai_text_to_embedding.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_openai_text_to_embedding.py +""" + +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding + + +@pytest.mark.asyncio +async def test_embedding(): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + + result = await oas3_openai_text_to_embedding("Panda emoji") + assert result + assert result.model + assert len(result.data) > 0 + assert len(result.data[0].embedding) > 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py new file mode 100644 index 000000000..24691a5e9 --- /dev/null +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_openai_text_to_image.py +""" + +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image + + +@pytest.mark.asyncio +async def test_draw(): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + + binary_data = await oas3_openai_text_to_image("Panda emoji") + assert binary_data + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_prompt_generator.py b/tests/metagpt/tools/test_prompt_writer.py similarity index 97% rename from tests/metagpt/tools/test_prompt_generator.py rename to tests/metagpt/tools/test_prompt_writer.py index ddbd2c43b..9f0c25ba1 100644 --- a/tests/metagpt/tools/test_prompt_generator.py +++ b/tests/metagpt/tools/test_prompt_writer.py @@ -3,7 +3,7 @@ """ @Time : 2023/5/2 17:46 @Author : alexanderwu -@File : test_prompt_generator.py +@File : test_prompt_writer.py """ import pytest diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 25bce124a..d13b1506e 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -9,6 +9,7 @@ from __future__ import annotations import pytest +from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -44,6 +45,15 @@ async def test_search_engine( max_results, as_string, ): + # Prerequisites + if search_engine_typpe is SearchEngineType.SERPAPI_GOOGLE: + assert CONFIG.SERPAPI_API_KEY and CONFIG.SERPAPI_API_KEY != "YOUR_API_KEY" + elif search_engine_typpe is SearchEngineType.DIRECT_GOOGLE: + assert CONFIG.GOOGLE_API_KEY and CONFIG.GOOGLE_API_KEY != "YOUR_API_KEY" + assert CONFIG.GOOGLE_CSE_ID and CONFIG.GOOGLE_CSE_ID != "YOUR_CSE_ID" + elif search_engine_typpe is SearchEngineType.SERPER_GOOGLE: + assert CONFIG.SERPER_API_KEY and CONFIG.SERPER_API_KEY != "YOUR_API_KEY" + search_engine = SearchEngine(search_engine_typpe, run_func) rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) logger.info(rsp) @@ -52,3 +62,7 @@ async def test_search_engine( else: assert isinstance(rsp, list) assert len(rsp) == max_results + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_search_engine_meilisearch.py b/tests/metagpt/tools/test_search_engine_meilisearch.py index d5f7d162b..9e1fbfbb9 100644 --- a/tests/metagpt/tools/test_search_engine_meilisearch.py +++ b/tests/metagpt/tools/test_search_engine_meilisearch.py @@ -18,6 +18,10 @@ MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk" @pytest.fixture() def search_engine_server(): + # Prerequisites + # https://www.meilisearch.com/docs/learn/getting_started/installation + # brew update && brew install meilisearch + meilisearch_process = subprocess.Popen(["meilisearch", "--master-key", f"{MASTER_KEY}"], stdout=subprocess.PIPE) time.sleep(3) yield @@ -26,6 +30,10 @@ def search_engine_server(): def test_meilisearch(search_engine_server): + # Prerequisites + # https://www.meilisearch.com/docs/learn/getting_started/installation + # brew update && brew install meilisearch + search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY) # 假设有一个名为"books"的数据源,包含要添加的文档库 @@ -44,3 +52,7 @@ def test_meilisearch(search_engine_server): # 添加文档库到搜索引擎 search_engine.add_documents(books_data_source, documents) logger.info(search_engine.search("Book 1")) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_ut_generator.py b/tests/metagpt/tools/test_ut_writer.py similarity index 100% rename from tests/metagpt/tools/test_ut_generator.py rename to tests/metagpt/tools/test_ut_writer.py diff --git a/tests/metagpt/tools/test_web_browser_engine.py b/tests/metagpt/tools/test_web_browser_engine.py index 1e4e956f2..289edda2f 100644 --- a/tests/metagpt/tools/test_web_browser_engine.py +++ b/tests/metagpt/tools/test_web_browser_engine.py @@ -4,8 +4,8 @@ import pytest -from metagpt.config import Config from metagpt.tools import WebBrowserEngineType, web_browser_engine +from metagpt.utils.parse_html import WebPage @pytest.mark.asyncio @@ -18,14 +18,17 @@ from metagpt.tools import WebBrowserEngineType, web_browser_engine ids=["playwright", "selenium"], ) async def test_scrape_web_page(browser_type, url, urls): - conf = Config() - browser = web_browser_engine.WebBrowserEngine(options=conf.runtime_options, engine=browser_type) + browser = web_browser_engine.WebBrowserEngine(engine=browser_type) result = await browser.run(url) - assert isinstance(result, str) - assert "深度赋智" in result + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("深度赋智" in i) for i in results) + assert all(("MetaGPT" in i.inner_text) for i in results) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py index cc6c09925..1e23ebb31 100644 --- a/tests/metagpt/tools/test_web_browser_engine_playwright.py +++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py @@ -4,8 +4,9 @@ import pytest -from metagpt.config import Config +from metagpt.config import CONFIG from metagpt.tools import web_browser_engine_playwright +from metagpt.utils.parse_html import WebPage @pytest.mark.asyncio @@ -19,25 +20,25 @@ from metagpt.tools import web_browser_engine_playwright ids=["chromium-normal", "firefox-normal", "webkit-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd): - conf = Config() - global_proxy = conf.global_proxy + global_proxy = CONFIG.global_proxy try: if use_proxy: - conf.global_proxy = proxy - browser = web_browser_engine_playwright.PlaywrightWrapper( - options=conf.runtime_options, browser_type=browser_type, **kwagrs - ) + CONFIG.global_proxy = proxy + browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs) result = await browser.run(url) - result = result.inner_text - assert isinstance(result, str) - assert "DeepWisdom" in result + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("DeepWisdom" in i) for i in results) + assert all(("MetaGPT" in i.inner_text) for i in results) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: - conf.global_proxy = global_proxy + CONFIG.global_proxy = global_proxy + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py index 77f4d8592..a2ac2f933 100644 --- a/tests/metagpt/tools/test_web_browser_engine_selenium.py +++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py @@ -4,8 +4,9 @@ import pytest -from metagpt.config import Config +from metagpt.config import CONFIG from metagpt.tools import web_browser_engine_selenium +from metagpt.utils.parse_html import WebPage @pytest.mark.asyncio @@ -19,23 +20,28 @@ from metagpt.tools import web_browser_engine_selenium ids=["chrome-normal", "firefox-normal", "edge-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd): - conf = Config() - global_proxy = conf.global_proxy + # Prerequisites + # firefox, chrome, Microsoft Edge + + global_proxy = CONFIG.global_proxy try: if use_proxy: - conf.global_proxy = proxy - browser = web_browser_engine_selenium.SeleniumWrapper(options=conf.runtime_options, browser_type=browser_type) + CONFIG.global_proxy = proxy + browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type) result = await browser.run(url) - result = result.inner_text - assert isinstance(result, str) - assert "Deepwisdom" in result + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("Deepwisdom" in i.inner_text) for i in results) + assert all(("MetaGPT" in i.inner_text) for i in results) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: - conf.global_proxy = global_proxy + CONFIG.global_proxy = global_proxy + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) From cfedba061afc1537b7a120653c7fab3d30346d46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 26 Dec 2023 22:22:29 +0800 Subject: [PATCH 3/4] feat: +unit test --- .gitignore | 1 + tests/metagpt/tools/test_ut_writer.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 039ba1956..67c2fa316 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ output tmp.png .dependencies.json tests/metagpt/utils/file_repo_git +*.tmp diff --git a/tests/metagpt/tools/test_ut_writer.py b/tests/metagpt/tools/test_ut_writer.py index 2ae94885f..e31afa702 100644 --- a/tests/metagpt/tools/test_ut_writer.py +++ b/tests/metagpt/tools/test_ut_writer.py @@ -3,8 +3,11 @@ """ @Time : 2023/4/30 21:44 @Author : alexanderwu -@File : test_ut_generator.py +@File : test_ut_writer.py """ +from pathlib import Path + +import pytest from metagpt.const import API_QUESTIONS_PATH, SWAGGER_PATH, UT_PY_PATH from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator @@ -12,7 +15,10 @@ from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator class TestUTWriter: def test_api_to_ut_sample(self): + # Prerequisites swagger_file = SWAGGER_PATH / "yft_swaggerApi.json" + assert swagger_file.exists() + tags = ["测试"] # "智能合同导入", "律师审查", "ai合同审查", "草拟合同&律师在线审查", "合同审批", "履约管理", "签约公司"] # 这里在文件中手动加入了两个测试标签的API @@ -25,3 +31,12 @@ class TestUTWriter: ret = utg.generate_ut(include_tags=tags) # 后续加入对文件生成内容与数量的检验 assert ret + + pathname = Path(__file__).with_suffix(".tmp") + utg.ask_gpt_and_save(question="question", tag="tag", fname=str(pathname)) + assert pathname.exists() + pathname.unlink(missing_ok=True) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) From 8d925e50f164bf66a8f593b07a21c57ba161ab96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 26 Dec 2023 22:35:51 +0800 Subject: [PATCH 4/4] refactor: pre-commit --- tests/metagpt/actions/test_invoice_ocr.py | 5 +---- tests/metagpt/roles/test_invoice_ocr_assistant.py | 7 ++----- tests/metagpt/roles/test_tutorial_assistant.py | 1 + 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index b3b93cf9f..12b1b4b30 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -34,10 +34,7 @@ async def test_invoice_ocr(invoice_path: str): @pytest.mark.parametrize( ("invoice_path", "expected_result"), [ - ( - "../../data/invoices/invoice-1.pdf", - [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}] - ), + ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), ], ) async def test_generate_table(invoice_path: str, expected_result: list[dict]): diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index 48abb9eb8..500d93a77 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -37,12 +37,10 @@ from metagpt.schema import Message Path("../../data/invoices/invoice-3.jpg"), Path("../../../data/invoice_table/invoice-3.xlsx"), {"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, - ) + ), ], ) -async def test_invoice_ocr_assistant( - query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict -): +async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict): invoice_path = Path.cwd() / invoice_path role = InvoiceOCRAssistant() await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) @@ -56,4 +54,3 @@ async def test_invoice_ocr_assistant( assert expected_result["城市"] in resp["城市"] assert int(expected_result["总费用/元"]) == int(resp["总费用/元"]) assert expected_result["开票日期"] == resp["开票日期"] - diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py index 4455e1bf6..ca54aaff5 100644 --- a/tests/metagpt/roles/test_tutorial_assistant.py +++ b/tests/metagpt/roles/test_tutorial_assistant.py @@ -6,6 +6,7 @@ @File : test_tutorial_assistant.py """ import shutil + import aiofiles import pytest