diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index eaf528b3e..c3c62fb67 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -6,6 +6,7 @@ @File : text_to_image.py @Desc : Text-to-Image skill, which provides text-to-image functionality. """ +import base64 from metagpt.config import CONFIG from metagpt.const import BASE64_FORMAT @@ -25,11 +26,12 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod """ image_declaration = "data:image/png;base64," if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url: - base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url) + binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url) elif CONFIG.OPENAI_API_KEY or openai_api_key: - base64_data = await oas3_openai_text_to_image(text, size_type) + binary_data = await oas3_openai_text_to_image(text, size_type) else: raise ValueError("Missing necessary parameters.") + base64_data = base64.b64encode(binary_data).decode("utf-8") s3 = S3() url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else "" diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 72958b8c7..ecd00c724 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -6,7 +6,6 @@ @File : text_to_speech.py @Desc : Text-to-Speech skill, which provides text-to-speech functionality """ -import openai from metagpt.config import CONFIG from metagpt.const import BASE64_FORMAT @@ -66,7 +65,6 @@ async def text_to_speech( return f"[{text}]({url})" return audio_declaration + base64_data if base64_data else base64_data - raise openai.InvalidRequestError( - message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error", - param={}, + raise ValueError( + "AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error" ) diff --git a/metagpt/tools/azure_tts.py b/metagpt/tools/azure_tts.py index d3e67c269..f4f8aa0a2 100644 --- a/metagpt/tools/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -96,9 +96,10 @@ async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscripti async with aiofiles.open(filename, mode="rb") as reader: data = await reader.read() base64_string = base64.b64encode(data).decode("utf-8") - filename.unlink() except Exception as e: logger.error(f"text:{text}, error:{e}") return "" + finally: + filename.unlink(missing_ok=True) return base64_string diff --git a/metagpt/tools/hello.py b/metagpt/tools/hello.py index 52d2d11c1..ec7fc9231 100644 --- a/metagpt/tools/hello.py +++ b/metagpt/tools/hello.py @@ -7,7 +7,7 @@ @Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service: curl -X 'POST' \ - 'http://localhost:8080/openapi/greeting/dave' \ + 'http://localhost:8082/openapi/greeting/dave' \ -H 'accept: text/plain' \ -H 'Content-Type: application/json' \ -d '{}' @@ -26,4 +26,4 @@ if __name__ == "__main__": specification_dir = Path(__file__).parent.parent.parent / ".well-known" app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("openapi.yaml", arguments={"title": "Hello World Example"}) - app.run(port=8080) + app.run(port=8082) diff --git a/metagpt/tools/iflytek_tts.py b/metagpt/tools/iflytek_tts.py index cb87d2e7f..ad2395362 100644 --- a/metagpt/tools/iflytek_tts.py +++ b/metagpt/tools/iflytek_tts.py @@ -6,7 +6,6 @@ @File : iflytek_tts.py @Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality """ -import asyncio import base64 import hashlib import hmac @@ -74,12 +73,13 @@ class IFlyTekTTS(object): await websocket.send(req) # receive frames - async with aiofiles.open(str(output_file), "w") as writer: + async with aiofiles.open(str(output_file), "wb") as writer: while True: v = await websocket.recv() rsp = IFlyTekTTSResponse(**json.loads(v)) if rsp.data: - await writer.write(rsp.data.audio) + binary_data = base64.b64decode(rsp.data.audio) + await writer.write(binary_data) if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value: continue break @@ -140,23 +140,13 @@ async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key try: tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret) await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice) - async with aiofiles.open(str(filename), mode="r") as reader: - base64_string = await reader.read() + async with aiofiles.open(str(filename), mode="rb") as reader: + data = await reader.read() + base64_string = base64.b64encode(data).decode("utf-8") except Exception as e: logger.error(f"text:{text}, error:{e}") base64_string = "" finally: - filename.unlink() + filename.unlink(missing_ok=True) return base64_string - - -if __name__ == "__main__": - asyncio.get_event_loop().run_until_complete( - oas3_iflytek_tts( - text="你好,hello", - app_id="f7acef62", - api_key="fda72e3aa286042a492525816a5efa08", - api_secret="ZDk3NjdiMDBkODJlOWQ1NjRjMGI2NDY4", - ) - ) diff --git a/metagpt/tools/metagpt_oas3_api_svc.py b/metagpt/tools/metagpt_oas3_api_svc.py index 2ff4c8225..319e7efb2 100644 --- a/metagpt/tools/metagpt_oas3_api_svc.py +++ b/metagpt/tools/metagpt_oas3_api_svc.py @@ -6,39 +6,21 @@ @File : metagpt_oas3_api_svc.py @Desc : MetaGPT OpenAPI Specification 3.0 REST API service """ -import asyncio -import sys + from pathlib import Path import connexion -sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt' - def oas_http_svc(): """Start the OAS 3.0 OpenAPI HTTP service""" - app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/") + print("http://localhost:8080/oas3/ui/") + specification_dir = Path(__file__).parent.parent.parent / ".well-known" + app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("metagpt_oas3_api.yaml") app.add_api("openapi.yaml") app.run(port=8080) -async def async_main(): - """Start the OAS 3.0 OpenAPI HTTP service in the background.""" - loop = asyncio.get_event_loop() - loop.run_in_executor(None, oas_http_svc) - - # TODO: replace following codes: - while True: - await asyncio.sleep(1) - print("sleep") - - -def main(): - print("http://localhost:8080/oas3/ui/") - oas_http_svc() - - if __name__ == "__main__": - # asyncio.run(async_main()) - main() + oas_http_svc() diff --git a/metagpt/tools/metagpt_text_to_image.py b/metagpt/tools/metagpt_text_to_image.py index 50c0edcba..9a84e69eb 100644 --- a/metagpt/tools/metagpt_text_to_image.py +++ b/metagpt/tools/metagpt_text_to_image.py @@ -6,7 +6,6 @@ @File : metagpt_text_to_image.py @Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality. """ -import asyncio import base64 from typing import Dict, List @@ -14,7 +13,7 @@ import aiohttp import requests from pydantic import BaseModel -from metagpt.config import CONFIG, Config +from metagpt.config import CONFIG from metagpt.logs import logger @@ -75,11 +74,12 @@ class MetaGPTText2Image: async with session.post(self.model_url, headers=headers, json=data) as response: result = ImageResult(**await response.json()) if len(result.images) == 0: - return "" - return result.images[0] + return 0 + data = base64.b64decode(result.images[0]) + return data except requests.exceptions.RequestException as e: logger.error(f"An error occurred:{e}") - return "" + return 0 # Export @@ -96,15 +96,3 @@ async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url if not model_url: model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type) - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - task = loop.create_task(oas3_metagpt_text_to_image("Panda emoji")) - v = loop.run_until_complete(task) - print(v) - data = base64.b64decode(v) - with open("tmp.png", mode="wb") as writer: - writer.write(data) - print(v) diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py index 5532e4f66..e4b23d538 100644 --- a/metagpt/tools/moderation.py +++ b/metagpt/tools/moderation.py @@ -5,7 +5,6 @@ @Author : zhanglei @File : moderation.py """ -import asyncio from typing import Union from metagpt.llm import LLM @@ -15,6 +14,38 @@ class Moderation: def __init__(self): self.llm = LLM() + def handle_moderation_results(self, results): + resp = [] + for item in results: + categories = item.categories.dict() + true_categories = [category for category, item_flagged in categories.items() if item_flagged] + resp.append({"flagged": item.flagged, "true_categories": true_categories}) + return resp + + def moderation_with_categories(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = self.llm.moderation(content=content) + resp = self.handle_moderation_results(moderation_results.results) + return resp + + async def amoderation_with_categories(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = await self.llm.amoderation(content=content) + resp = self.handle_moderation_results(moderation_results.results) + return resp + + def moderation(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = self.llm.moderation(content=content) + results = moderation_results.results + for item in results: + resp.append(item.flagged) + + return resp + async def amoderation(self, content: Union[str, list[str]]): resp = [] if content: @@ -24,15 +55,3 @@ class Moderation: resp.append(item.flagged) return resp - - -async def main(): - moderation = Moderation() - rsp = await moderation.amoderation( - content=["I will kill you", "The weather is really nice today", "I want to hit you"] - ) - print(rsp) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/metagpt/tools/openai_text_to_embedding.py b/metagpt/tools/openai_text_to_embedding.py index fb6fbc653..52b2cc9eb 100644 --- a/metagpt/tools/openai_text_to_embedding.py +++ b/metagpt/tools/openai_text_to_embedding.py @@ -7,14 +7,13 @@ @Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality. For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object` """ -import asyncio from typing import List import aiohttp import requests -from pydantic import BaseModel +from pydantic import BaseModel, Field -from metagpt.config import CONFIG, Config +from metagpt.config import CONFIG from metagpt.logs import logger @@ -29,15 +28,18 @@ class Embedding(BaseModel): class Usage(BaseModel): - prompt_tokens: int - total_tokens: int + prompt_tokens: int = 0 + total_tokens: int = 0 class ResultEmbedding(BaseModel): - object: str - data: List[Embedding] - model: str - usage: Usage + class Config: + alias = {"object_": "object"} + + object_: str = "" + data: List[Embedding] = [] + model: str = "" + usage: Usage = Field(default_factory=Usage) class OpenAIText2Embedding: @@ -45,7 +47,7 @@ class OpenAIText2Embedding: """ :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ - self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY + self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY async def text_2_embedding(self, text, model="text-embedding-ada-002"): """Text to embedding @@ -55,15 +57,18 @@ class OpenAIText2Embedding: :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. """ + proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {} headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"} data = {"input": text, "model": model} + url = "https://api.openai.com/v1/embeddings" try: async with aiohttp.ClientSession() as session: - async with session.post("https://api.openai.com/v1/embeddings", headers=headers, json=data) as response: - return await response.json() + async with session.post(url, headers=headers, json=data, **proxies) as response: + data = await response.json() + return ResultEmbedding(**data) except requests.exceptions.RequestException as e: logger.error(f"An error occurred:{e}") - return {} + return ResultEmbedding() # Export @@ -80,11 +85,3 @@ async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", op if not openai_api_key: openai_api_key = CONFIG.OPENAI_API_KEY return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model) - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - task = loop.create_task(oas3_openai_text_to_embedding("Panda emoji")) - v = loop.run_until_complete(task) - print(v) diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index 71381d8f2..fcfa86c7d 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -6,13 +6,10 @@ @File : openai_text_to_image.py @Desc : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality. """ -import asyncio -import base64 import aiohttp import requests -from metagpt.config import Config from metagpt.llm import LLM from metagpt.logs import logger @@ -23,7 +20,6 @@ class OpenAIText2Image: :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ self._llm = LLM() - self._client = self._llm.async_client def __del__(self): if self._llm: @@ -37,7 +33,7 @@ class OpenAIText2Image: :return: The image data is returned in Base64 encoding. """ try: - result = await self._client.images.generate(prompt=text, n=1, size=size_type) + result = await self._llm.async_client.images.generate(prompt=text, n=1, size=size_type) except Exception as e: logger.error(f"An error occurred:{e}") return "" @@ -57,12 +53,11 @@ class OpenAIText2Image: async with session.get(url) as response: response.raise_for_status() # 如果是 4xx 或 5xx 响应,会引发异常 image_data = await response.read() - base64_image = base64.b64encode(image_data).decode("utf-8") - return base64_image + return image_data except requests.exceptions.RequestException as e: logger.error(f"An error occurred:{e}") - return "" + return 0 # Export @@ -76,11 +71,3 @@ async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"): if not text: return "" return await OpenAIText2Image().text_2_image(text, size_type=size_type) - - -if __name__ == "__main__": - Config() - loop = asyncio.new_event_loop() - task = loop.create_task(oas3_openai_text_to_image("Panda emoji")) - v = loop.run_until_complete(task) - print(v) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 628c8dea2..cabae7531 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -9,7 +9,7 @@ import asyncio import importlib from concurrent import futures from copy import deepcopy -from typing import Dict, Literal +from typing import Literal from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC @@ -33,7 +33,6 @@ class SeleniumWrapper: def __init__( self, - options: Dict, browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None, launch_kwargs: dict | None = None, *, diff --git a/requirements-test.txt b/requirements-test.txt index 39ba608b7..fcf265163 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,4 +2,11 @@ -r requirements.txt connexion[uvicorn]~=3.0.5 -azure-cognitiveservices-speech~=1.31.0 \ No newline at end of file +azure-cognitiveservices-speech~=1.31.0 +duckduckgo_search +serpapi +google +httplib2 +google_api_python_client +selenium +webdriver_manager \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f2566fb15..c8d21dfc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ faiss_cpu==1.7.4 fire==0.4.0 typer # godot==0.1.1 -# google_api_python_client==2.93.0 +# google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.1.16 langchain==0.0.352 loguru==0.6.0 diff --git a/tests/metagpt/tools/test_hello.py b/tests/metagpt/tools/test_hello.py index 037dcd1b7..fdf67ac35 100644 --- a/tests/metagpt/tools/test_hello.py +++ b/tests/metagpt/tools/test_hello.py @@ -5,6 +5,7 @@ @Author : mashenquan @File : test_hello.py """ +import asyncio import subprocess from pathlib import Path @@ -14,10 +15,11 @@ import requests @pytest.mark.asyncio async def test_hello(): - script_pathname = Path(__file__).resolve() + script_pathname = Path(__file__).parent / "../../../metagpt/tools/hello.py" process = subprocess.Popen(["python", str(script_pathname)]) + await asyncio.sleep(5) - url = "http://localhost:8080/openapi/greeting/dave" + url = "http://localhost:8082/openapi/greeting/dave" headers = {"accept": "text/plain", "Content-Type": "application/json"} data = {} response = requests.post(url, headers=headers, json=data) diff --git a/tests/metagpt/tools/test_iflytek_tts.py b/tests/metagpt/tools/test_iflytek_tts.py new file mode 100644 index 000000000..58d8a83ce --- /dev/null +++ b/tests/metagpt/tools/test_iflytek_tts.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_iflytek_tts.py +""" +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.iflytek_tts import oas3_iflytek_tts + + +@pytest.mark.asyncio +async def test_tts(): + # Prerequisites + assert CONFIG.IFLYTEK_APP_ID + assert CONFIG.IFLYTEK_API_KEY + assert CONFIG.IFLYTEK_API_SECRET + + result = await oas3_iflytek_tts( + text="你好,hello", + app_id=CONFIG.IFLYTEK_APP_ID, + api_key=CONFIG.IFLYTEK_API_KEY, + api_secret=CONFIG.IFLYTEK_API_SECRET, + ) + assert result + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_metagpt_oas3_api_svc.py b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py new file mode 100644 index 000000000..e0f17aa05 --- /dev/null +++ b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_metagpt_oas3_api_svc.py +""" +import asyncio +import subprocess +from pathlib import Path + +import pytest +import requests + + +@pytest.mark.asyncio +async def test_oas2_svc(): + script_pathname = Path(__file__).parent / "../../../metagpt/tools/metagpt_oas3_api_svc.py" + process = subprocess.Popen(["python", str(script_pathname)]) + await asyncio.sleep(5) + + url = "http://localhost:8080/openapi/greeting/dave" + headers = {"accept": "text/plain", "Content-Type": "application/json"} + data = {} + response = requests.post(url, headers=headers, json=data) + assert response.text == "Hello dave\n" + + process.terminate() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_metagpt_text_to_image.py b/tests/metagpt/tools/test_metagpt_text_to_image.py new file mode 100644 index 000000000..f5ced2061 --- /dev/null +++ b/tests/metagpt/tools/test_metagpt_text_to_image.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_metagpt_text_to_image.py +""" + +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image + + +@pytest.mark.asyncio +async def test_draw(): + # Prerequisites + assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL + + binary_data = await oas3_metagpt_text_to_image("Panda emoji") + assert binary_data + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index 5ec3bd4de..c71611bd3 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -8,6 +8,7 @@ import pytest +from metagpt.config import CONFIG from metagpt.tools.moderation import Moderation @@ -20,11 +21,23 @@ from metagpt.tools.moderation import Moderation ], ) def test_moderation(content): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + moderation = Moderation() results = moderation.moderation(content=content) assert isinstance(results, list) assert len(results) == len(content) + results = moderation.moderation_with_categories(content=content) + assert isinstance(results, list) + assert results + for m in results: + assert "flagged" in m + assert "true_categories" in m + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -36,7 +49,23 @@ def test_moderation(content): ], ) async def test_amoderation(content): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + moderation = Moderation() results = await moderation.amoderation(content=content) assert isinstance(results, list) assert len(results) == len(content) + + results = await moderation.amoderation_with_categories(content=content) + assert isinstance(results, list) + assert results + for m in results: + assert "flagged" in m + assert "true_categories" in m + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_openai_text_to_embedding.py b/tests/metagpt/tools/test_openai_text_to_embedding.py new file mode 100644 index 000000000..086c9d45b --- /dev/null +++ b/tests/metagpt/tools/test_openai_text_to_embedding.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_openai_text_to_embedding.py +""" + +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding + + +@pytest.mark.asyncio +async def test_embedding(): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + + result = await oas3_openai_text_to_embedding("Panda emoji") + assert result + assert result.model + assert len(result.data) > 0 + assert len(result.data[0].embedding) > 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py new file mode 100644 index 000000000..24691a5e9 --- /dev/null +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mashenquan +@File : test_openai_text_to_image.py +""" + +import pytest + +from metagpt.config import CONFIG +from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image + + +@pytest.mark.asyncio +async def test_draw(): + # Prerequisites + assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY" + assert not CONFIG.OPENAI_API_TYPE + assert CONFIG.OPENAI_API_MODEL + + binary_data = await oas3_openai_text_to_image("Panda emoji") + assert binary_data + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_prompt_generator.py b/tests/metagpt/tools/test_prompt_writer.py similarity index 97% rename from tests/metagpt/tools/test_prompt_generator.py rename to tests/metagpt/tools/test_prompt_writer.py index ddbd2c43b..9f0c25ba1 100644 --- a/tests/metagpt/tools/test_prompt_generator.py +++ b/tests/metagpt/tools/test_prompt_writer.py @@ -3,7 +3,7 @@ """ @Time : 2023/5/2 17:46 @Author : alexanderwu -@File : test_prompt_generator.py +@File : test_prompt_writer.py """ import pytest diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 25bce124a..d13b1506e 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -9,6 +9,7 @@ from __future__ import annotations import pytest +from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -44,6 +45,15 @@ async def test_search_engine( max_results, as_string, ): + # Prerequisites + if search_engine_typpe is SearchEngineType.SERPAPI_GOOGLE: + assert CONFIG.SERPAPI_API_KEY and CONFIG.SERPAPI_API_KEY != "YOUR_API_KEY" + elif search_engine_typpe is SearchEngineType.DIRECT_GOOGLE: + assert CONFIG.GOOGLE_API_KEY and CONFIG.GOOGLE_API_KEY != "YOUR_API_KEY" + assert CONFIG.GOOGLE_CSE_ID and CONFIG.GOOGLE_CSE_ID != "YOUR_CSE_ID" + elif search_engine_typpe is SearchEngineType.SERPER_GOOGLE: + assert CONFIG.SERPER_API_KEY and CONFIG.SERPER_API_KEY != "YOUR_API_KEY" + search_engine = SearchEngine(search_engine_typpe, run_func) rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) logger.info(rsp) @@ -52,3 +62,7 @@ async def test_search_engine( else: assert isinstance(rsp, list) assert len(rsp) == max_results + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_search_engine_meilisearch.py b/tests/metagpt/tools/test_search_engine_meilisearch.py index d5f7d162b..9e1fbfbb9 100644 --- a/tests/metagpt/tools/test_search_engine_meilisearch.py +++ b/tests/metagpt/tools/test_search_engine_meilisearch.py @@ -18,6 +18,10 @@ MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk" @pytest.fixture() def search_engine_server(): + # Prerequisites + # https://www.meilisearch.com/docs/learn/getting_started/installation + # brew update && brew install meilisearch + meilisearch_process = subprocess.Popen(["meilisearch", "--master-key", f"{MASTER_KEY}"], stdout=subprocess.PIPE) time.sleep(3) yield @@ -26,6 +30,10 @@ def search_engine_server(): def test_meilisearch(search_engine_server): + # Prerequisites + # https://www.meilisearch.com/docs/learn/getting_started/installation + # brew update && brew install meilisearch + search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY) # 假设有一个名为"books"的数据源,包含要添加的文档库 @@ -44,3 +52,7 @@ def test_meilisearch(search_engine_server): # 添加文档库到搜索引擎 search_engine.add_documents(books_data_source, documents) logger.info(search_engine.search("Book 1")) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_ut_generator.py b/tests/metagpt/tools/test_ut_writer.py similarity index 100% rename from tests/metagpt/tools/test_ut_generator.py rename to tests/metagpt/tools/test_ut_writer.py diff --git a/tests/metagpt/tools/test_web_browser_engine.py b/tests/metagpt/tools/test_web_browser_engine.py index 1e4e956f2..289edda2f 100644 --- a/tests/metagpt/tools/test_web_browser_engine.py +++ b/tests/metagpt/tools/test_web_browser_engine.py @@ -4,8 +4,8 @@ import pytest -from metagpt.config import Config from metagpt.tools import WebBrowserEngineType, web_browser_engine +from metagpt.utils.parse_html import WebPage @pytest.mark.asyncio @@ -18,14 +18,17 @@ from metagpt.tools import WebBrowserEngineType, web_browser_engine ids=["playwright", "selenium"], ) async def test_scrape_web_page(browser_type, url, urls): - conf = Config() - browser = web_browser_engine.WebBrowserEngine(options=conf.runtime_options, engine=browser_type) + browser = web_browser_engine.WebBrowserEngine(engine=browser_type) result = await browser.run(url) - assert isinstance(result, str) - assert "深度赋智" in result + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("深度赋智" in i) for i in results) + assert all(("MetaGPT" in i.inner_text) for i in results) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py index cc6c09925..1e23ebb31 100644 --- a/tests/metagpt/tools/test_web_browser_engine_playwright.py +++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py @@ -4,8 +4,9 @@ import pytest -from metagpt.config import Config +from metagpt.config import CONFIG from metagpt.tools import web_browser_engine_playwright +from metagpt.utils.parse_html import WebPage @pytest.mark.asyncio @@ -19,25 +20,25 @@ from metagpt.tools import web_browser_engine_playwright ids=["chromium-normal", "firefox-normal", "webkit-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd): - conf = Config() - global_proxy = conf.global_proxy + global_proxy = CONFIG.global_proxy try: if use_proxy: - conf.global_proxy = proxy - browser = web_browser_engine_playwright.PlaywrightWrapper( - options=conf.runtime_options, browser_type=browser_type, **kwagrs - ) + CONFIG.global_proxy = proxy + browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs) result = await browser.run(url) - result = result.inner_text - assert isinstance(result, str) - assert "DeepWisdom" in result + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("DeepWisdom" in i) for i in results) + assert all(("MetaGPT" in i.inner_text) for i in results) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: - conf.global_proxy = global_proxy + CONFIG.global_proxy = global_proxy + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py index 77f4d8592..a2ac2f933 100644 --- a/tests/metagpt/tools/test_web_browser_engine_selenium.py +++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py @@ -4,8 +4,9 @@ import pytest -from metagpt.config import Config +from metagpt.config import CONFIG from metagpt.tools import web_browser_engine_selenium +from metagpt.utils.parse_html import WebPage @pytest.mark.asyncio @@ -19,23 +20,28 @@ from metagpt.tools import web_browser_engine_selenium ids=["chrome-normal", "firefox-normal", "edge-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd): - conf = Config() - global_proxy = conf.global_proxy + # Prerequisites + # firefox, chrome, Microsoft Edge + + global_proxy = CONFIG.global_proxy try: if use_proxy: - conf.global_proxy = proxy - browser = web_browser_engine_selenium.SeleniumWrapper(options=conf.runtime_options, browser_type=browser_type) + CONFIG.global_proxy = proxy + browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type) result = await browser.run(url) - result = result.inner_text - assert isinstance(result, str) - assert "Deepwisdom" in result + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("Deepwisdom" in i.inner_text) for i in results) + assert all(("MetaGPT" in i.inner_text) for i in results) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: - conf.global_proxy = global_proxy + CONFIG.global_proxy = global_proxy + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])