mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
commit
641c71bf18
33 changed files with 423 additions and 194 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -164,4 +164,6 @@ metagpt/roles/idea_agent.py
|
|||
# output folder
|
||||
output
|
||||
tmp.png
|
||||
|
||||
.dependencies.json
|
||||
tests/metagpt/utils/file_repo_git
|
||||
*.tmp
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@File : text_to_image.py
|
||||
@Desc : Text-to-Image skill, which provides text-to-image functionality.
|
||||
"""
|
||||
import base64
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
|
|
@ -25,11 +26,12 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod
|
|||
"""
|
||||
image_declaration = "data:image/png;base64,"
|
||||
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
|
||||
base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
elif CONFIG.OPENAI_API_KEY or openai_api_key:
|
||||
base64_data = await oas3_openai_text_to_image(text, size_type)
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type)
|
||||
else:
|
||||
raise ValueError("Missing necessary parameters.")
|
||||
base64_data = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@File : text_to_speech.py
|
||||
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
|
||||
"""
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
|
|
@ -66,7 +65,6 @@ async def text_to_speech(
|
|||
return f"[{text}]({url})"
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
||||
raise openai.InvalidRequestError(
|
||||
message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error",
|
||||
param={},
|
||||
raise ValueError(
|
||||
"AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@File : azure_tts.py
|
||||
@Modified by: mashenquan, 2023/8/17. Azure TTS OAS3 api, which provides text-to-speech functionality
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
|
@ -14,7 +13,7 @@ from uuid import uuid4
|
|||
import aiofiles
|
||||
from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -97,17 +96,10 @@ async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscripti
|
|||
async with aiofiles.open(filename, mode="rb") as reader:
|
||||
data = await reader.read()
|
||||
base64_string = base64.b64encode(data).decode("utf-8")
|
||||
filename.unlink()
|
||||
except Exception as e:
|
||||
logger.error(f"text:{text}, error:{e}")
|
||||
return ""
|
||||
finally:
|
||||
filename.unlink(missing_ok=True)
|
||||
|
||||
return base64_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Config()
|
||||
loop = asyncio.new_event_loop()
|
||||
v = loop.create_task(oas3_azsure_tts("测试,test"))
|
||||
loop.run_until_complete(v)
|
||||
print(v)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@
|
|||
@Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service:
|
||||
|
||||
curl -X 'POST' \
|
||||
'http://localhost:8080/openapi/greeting/dave' \
|
||||
'http://localhost:8082/openapi/greeting/dave' \
|
||||
-H 'accept: text/plain' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}'
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import connexion
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ async def post_greeting(name: str) -> str:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/")
|
||||
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
|
||||
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
|
||||
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
|
||||
app.run(port=8080)
|
||||
app.run(port=8082)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@File : iflytek_tts.py
|
||||
@Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
|
|
@ -74,12 +73,13 @@ class IFlyTekTTS(object):
|
|||
await websocket.send(req)
|
||||
|
||||
# receive frames
|
||||
async with aiofiles.open(str(output_file), "w") as writer:
|
||||
async with aiofiles.open(str(output_file), "wb") as writer:
|
||||
while True:
|
||||
v = await websocket.recv()
|
||||
rsp = IFlyTekTTSResponse(**json.loads(v))
|
||||
if rsp.data:
|
||||
await writer.write(rsp.data.audio)
|
||||
binary_data = base64.b64decode(rsp.data.audio)
|
||||
await writer.write(binary_data)
|
||||
if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value:
|
||||
continue
|
||||
break
|
||||
|
|
@ -140,23 +140,13 @@ async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key
|
|||
try:
|
||||
tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret)
|
||||
await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice)
|
||||
async with aiofiles.open(str(filename), mode="r") as reader:
|
||||
base64_string = await reader.read()
|
||||
async with aiofiles.open(str(filename), mode="rb") as reader:
|
||||
data = await reader.read()
|
||||
base64_string = base64.b64encode(data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"text:{text}, error:{e}")
|
||||
base64_string = ""
|
||||
finally:
|
||||
filename.unlink()
|
||||
filename.unlink(missing_ok=True)
|
||||
|
||||
return base64_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
oas3_iflytek_tts(
|
||||
text="你好,hello",
|
||||
app_id="f7acef62",
|
||||
api_key="fda72e3aa286042a492525816a5efa08",
|
||||
api_secret="ZDk3NjdiMDBkODJlOWQ1NjRjMGI2NDY4",
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,39 +6,21 @@
|
|||
@File : metagpt_oas3_api_svc.py
|
||||
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import connexion
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt'
|
||||
|
||||
|
||||
def oas_http_svc():
|
||||
"""Start the OAS 3.0 OpenAPI HTTP service"""
|
||||
app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/")
|
||||
print("http://localhost:8080/oas3/ui/")
|
||||
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
|
||||
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
|
||||
app.add_api("metagpt_oas3_api.yaml")
|
||||
app.add_api("openapi.yaml")
|
||||
app.run(port=8080)
|
||||
|
||||
|
||||
async def async_main():
|
||||
"""Start the OAS 3.0 OpenAPI HTTP service in the background."""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_in_executor(None, oas_http_svc)
|
||||
|
||||
# TODO: replace following codes:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
print("sleep")
|
||||
|
||||
|
||||
def main():
|
||||
print("http://localhost:8080/oas3/ui/")
|
||||
oas_http_svc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# asyncio.run(async_main())
|
||||
main()
|
||||
oas_http_svc()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@File : metagpt_text_to_image.py
|
||||
@Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
@ -14,7 +13,7 @@ import aiohttp
|
|||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -75,11 +74,12 @@ class MetaGPTText2Image:
|
|||
async with session.post(self.model_url, headers=headers, json=data) as response:
|
||||
result = ImageResult(**await response.json())
|
||||
if len(result.images) == 0:
|
||||
return ""
|
||||
return result.images[0]
|
||||
return 0
|
||||
data = base64.b64decode(result.images[0])
|
||||
return data
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
return 0
|
||||
|
||||
|
||||
# Export
|
||||
|
|
@ -96,15 +96,3 @@ async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url
|
|||
if not model_url:
|
||||
model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Config()
|
||||
loop = asyncio.new_event_loop()
|
||||
task = loop.create_task(oas3_metagpt_text_to_image("Panda emoji"))
|
||||
v = loop.run_until_complete(task)
|
||||
print(v)
|
||||
data = base64.b64decode(v)
|
||||
with open("tmp.png", mode="wb") as writer:
|
||||
writer.write(data)
|
||||
print(v)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
@Author : zhanglei
|
||||
@File : moderation.py
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Union
|
||||
|
||||
from metagpt.llm import LLM
|
||||
|
|
@ -15,6 +14,21 @@ class Moderation:
|
|||
def __init__(self):
|
||||
self.llm = LLM()
|
||||
|
||||
def handle_moderation_results(self, results):
|
||||
resp = []
|
||||
for item in results:
|
||||
categories = item.categories.dict()
|
||||
true_categories = [category for category, item_flagged in categories.items() if item_flagged]
|
||||
resp.append({"flagged": item.flagged, "true_categories": true_categories})
|
||||
return resp
|
||||
|
||||
async def amoderation_with_categories(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
moderation_results = await self.llm.amoderation(content=content)
|
||||
resp = self.handle_moderation_results(moderation_results.results)
|
||||
return resp
|
||||
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
|
|
@ -24,15 +38,3 @@ class Moderation:
|
|||
resp.append(item.flagged)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
async def main():
|
||||
moderation = Moderation()
|
||||
rsp = await moderation.amoderation(
|
||||
content=["I will kill you", "The weather is really nice today", "I want to hit you"]
|
||||
)
|
||||
print(rsp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -7,14 +7,13 @@
|
|||
@Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality.
|
||||
For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object`
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -29,15 +28,18 @@ class Embedding(BaseModel):
|
|||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class ResultEmbedding(BaseModel):
|
||||
object: str
|
||||
data: List[Embedding]
|
||||
model: str
|
||||
usage: Usage
|
||||
class Config:
|
||||
alias = {"object_": "object"}
|
||||
|
||||
object_: str = ""
|
||||
data: List[Embedding] = []
|
||||
model: str = ""
|
||||
usage: Usage = Field(default_factory=Usage)
|
||||
|
||||
|
||||
class OpenAIText2Embedding:
|
||||
|
|
@ -45,7 +47,7 @@ class OpenAIText2Embedding:
|
|||
"""
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY
|
||||
self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY
|
||||
|
||||
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
|
||||
"""Text to embedding
|
||||
|
|
@ -55,15 +57,18 @@ class OpenAIText2Embedding:
|
|||
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
|
||||
"""
|
||||
|
||||
proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
|
||||
data = {"input": text, "model": model}
|
||||
url = "https://api.openai.com/v1/embeddings"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post("https://api.openai.com/v1/embeddings", headers=headers, json=data) as response:
|
||||
return await response.json()
|
||||
async with session.post(url, headers=headers, json=data, **proxies) as response:
|
||||
data = await response.json()
|
||||
return ResultEmbedding(**data)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return {}
|
||||
return ResultEmbedding()
|
||||
|
||||
|
||||
# Export
|
||||
|
|
@ -80,11 +85,3 @@ async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", op
|
|||
if not openai_api_key:
|
||||
openai_api_key = CONFIG.OPENAI_API_KEY
|
||||
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Config()
|
||||
loop = asyncio.new_event_loop()
|
||||
task = loop.create_task(oas3_openai_text_to_embedding("Panda emoji"))
|
||||
v = loop.run_until_complete(task)
|
||||
print(v)
|
||||
|
|
|
|||
|
|
@ -6,13 +6,10 @@
|
|||
@File : openai_text_to_image.py
|
||||
@Desc : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
|
@ -23,7 +20,6 @@ class OpenAIText2Image:
|
|||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self._llm = LLM()
|
||||
self._client = self._llm.async_client
|
||||
|
||||
async def text_2_image(self, text, size_type="1024x1024"):
|
||||
"""Text to image
|
||||
|
|
@ -33,7 +29,7 @@ class OpenAIText2Image:
|
|||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
try:
|
||||
result = await self._client.images.generate(prompt=text, n=1, size=size_type)
|
||||
result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
|
|
@ -53,12 +49,11 @@ class OpenAIText2Image:
|
|||
async with session.get(url) as response:
|
||||
response.raise_for_status() # 如果是 4xx 或 5xx 响应,会引发异常
|
||||
image_data = await response.read()
|
||||
base64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
return base64_image
|
||||
return image_data
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
return 0
|
||||
|
||||
|
||||
# Export
|
||||
|
|
@ -72,11 +67,3 @@ async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"):
|
|||
if not text:
|
||||
return ""
|
||||
return await OpenAIText2Image().text_2_image(text, size_type=size_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Config()
|
||||
loop = asyncio.new_event_loop()
|
||||
task = loop.create_task(oas3_openai_text_to_image("Panda emoji"))
|
||||
v = loop.run_until_complete(task)
|
||||
print(v)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import asyncio
|
|||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Literal
|
||||
from typing import Literal
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
|
|
@ -33,7 +33,6 @@ class SeleniumWrapper:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
options: Dict,
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None,
|
||||
launch_kwargs: dict | None = None,
|
||||
*,
|
||||
|
|
|
|||
12
requirements-test.txt
Normal file
12
requirements-test.txt
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
# For unit test
|
||||
-r requirements.txt
|
||||
|
||||
connexion[uvicorn]~=3.0.5
|
||||
azure-cognitiveservices-speech~=1.31.0
|
||||
duckduckgo_search
|
||||
serpapi
|
||||
google
|
||||
httplib2
|
||||
google_api_python_client
|
||||
selenium
|
||||
webdriver_manager
|
||||
|
|
@ -9,7 +9,7 @@ faiss_cpu==1.7.4
|
|||
fire==0.4.0
|
||||
typer
|
||||
# godot==0.1.1
|
||||
# google_api_python_client==2.93.0
|
||||
# google_api_python_client==2.93.0 # Used by search_engine.py
|
||||
lancedb==0.1.16
|
||||
langchain==0.0.352
|
||||
loguru==0.6.0
|
||||
|
|
@ -40,12 +40,12 @@ typing_extensions==4.7.0
|
|||
libcst==1.0.1
|
||||
qdrant-client==1.4.0
|
||||
pytest-mock==3.11.1
|
||||
# open-interpreter==0.1.7; python_version>"3.9"
|
||||
# open-interpreter==0.1.7; python_version>"3.9" # Conflict with openai 1.x
|
||||
ta==0.10.2
|
||||
semantic-kernel==0.4.0.dev0
|
||||
wrapt==1.15.0
|
||||
#aiohttp_jinja2
|
||||
#azure-cognitiveservices-speech~=1.31.0
|
||||
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py
|
||||
#aioboto3~=11.3.0
|
||||
#redis==4.3.5
|
||||
websocket-client==1.6.2
|
||||
|
|
@ -54,8 +54,8 @@ gitpython==3.1.40
|
|||
zhipuai==1.0.7
|
||||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[swagger-ui]
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/hello.py
|
||||
websockets~=12.0
|
||||
networkx~=3.2.1
|
||||
google-generativeai==0.3.1
|
||||
playwright==1.40.0
|
||||
playwright==1.40.0
|
||||
|
|
|
|||
|
|
@ -34,10 +34,7 @@ async def test_invoice_ocr(invoice_path: str):
|
|||
@pytest.mark.parametrize(
|
||||
("invoice_path", "expected_result"),
|
||||
[
|
||||
(
|
||||
"../../data/invoices/invoice-1.pdf",
|
||||
[{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]
|
||||
),
|
||||
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
],
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
|
|
|
|||
|
|
@ -37,12 +37,10 @@ from metagpt.schema import Message
|
|||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict
|
||||
):
|
||||
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
|
|
@ -56,4 +54,3 @@ async def test_invoice_ocr_assistant(
|
|||
assert expected_result["城市"] in resp["城市"]
|
||||
assert int(expected_result["总费用/元"]) == int(resp["总费用/元"])
|
||||
assert expected_result["开票日期"] == resp["开票日期"]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
import shutil
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,20 @@
|
|||
@Modified By: mashenquan, 2023-8-9, add more text formatting options
|
||||
@Modified By: mashenquan, 2023-8-17, move to `tools` folder.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from azure.cognitiveservices.speech import ResultReason
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.azure_tts import AzureTTS
|
||||
|
||||
|
||||
def test_azure_tts():
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_tts():
|
||||
# Prerequisites
|
||||
assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert CONFIG.AZURE_TTS_REGION
|
||||
|
||||
azure_tts = AzureTTS(subscription_key="", region="")
|
||||
text = """
|
||||
女儿看见父亲走了进来,问道:
|
||||
|
|
@ -25,20 +32,19 @@ def test_azure_tts():
|
|||
“Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.”
|
||||
</mstts:express-as>
|
||||
"""
|
||||
path = CONFIG.workspace / "tts"
|
||||
path = CONFIG.workspace_path / "tts"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
filename = path / "girl.wav"
|
||||
loop = asyncio.new_event_loop()
|
||||
v = loop.create_task(
|
||||
azure_tts.synthesize_speech(lang="zh-CN", voice="zh-CN-XiaomoNeural", text=text, output_file=str(filename))
|
||||
filename.unlink(missing_ok=True)
|
||||
result = await azure_tts.synthesize_speech(
|
||||
lang="zh-CN", voice="zh-CN-XiaomoNeural", text=text, output_file=str(filename)
|
||||
)
|
||||
result = loop.run_until_complete(v)
|
||||
|
||||
print(result)
|
||||
|
||||
# 运行需要先配置 SUBSCRIPTION_KEY
|
||||
# TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有
|
||||
assert result
|
||||
assert result.audio_data
|
||||
assert result.reason == ResultReason.SynthesizingAudioCompleted
|
||||
assert filename.exists()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_azure_tts()
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
32
tests/metagpt/tools/test_hello.py
Normal file
32
tests/metagpt/tools/test_hello.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_hello.py
|
||||
"""
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hello():
|
||||
script_pathname = Path(__file__).parent / "../../../metagpt/tools/hello.py"
|
||||
process = subprocess.Popen(["python", str(script_pathname)])
|
||||
await asyncio.sleep(5)
|
||||
|
||||
url = "http://localhost:8082/openapi/greeting/dave"
|
||||
headers = {"accept": "text/plain", "Content-Type": "application/json"}
|
||||
data = {}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
assert response.text == "Hello dave\n"
|
||||
|
||||
process.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
31
tests/metagpt/tools/test_iflytek_tts.py
Normal file
31
tests/metagpt/tools/test_iflytek_tts.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_iflytek_tts.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tts():
|
||||
# Prerequisites
|
||||
assert CONFIG.IFLYTEK_APP_ID
|
||||
assert CONFIG.IFLYTEK_API_KEY
|
||||
assert CONFIG.IFLYTEK_API_SECRET
|
||||
|
||||
result = await oas3_iflytek_tts(
|
||||
text="你好,hello",
|
||||
app_id=CONFIG.IFLYTEK_APP_ID,
|
||||
api_key=CONFIG.IFLYTEK_API_KEY,
|
||||
api_secret=CONFIG.IFLYTEK_API_SECRET,
|
||||
)
|
||||
assert result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
32
tests/metagpt/tools/test_metagpt_oas3_api_svc.py
Normal file
32
tests/metagpt/tools/test_metagpt_oas3_api_svc.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_oas3_api_svc.py
|
||||
"""
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oas2_svc():
|
||||
script_pathname = Path(__file__).parent / "../../../metagpt/tools/metagpt_oas3_api_svc.py"
|
||||
process = subprocess.Popen(["python", str(script_pathname)])
|
||||
await asyncio.sleep(5)
|
||||
|
||||
url = "http://localhost:8080/openapi/greeting/dave"
|
||||
headers = {"accept": "text/plain", "Content-Type": "application/json"}
|
||||
data = {}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
assert response.text == "Hello dave\n"
|
||||
|
||||
process.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
25
tests/metagpt/tools/test_metagpt_text_to_image.py
Normal file
25
tests/metagpt/tools/test_metagpt_text_to_image.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_text_to_image.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draw():
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
|
||||
binary_data = await oas3_metagpt_text_to_image("Panda emoji")
|
||||
assert binary_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.moderation import Moderation
|
||||
|
||||
|
||||
|
|
@ -21,7 +22,23 @@ from metagpt.tools.moderation import Moderation
|
|||
],
|
||||
)
|
||||
async def test_amoderation(content):
|
||||
# Prerequisites
|
||||
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
|
||||
assert not CONFIG.OPENAI_API_TYPE
|
||||
assert CONFIG.OPENAI_API_MODEL
|
||||
|
||||
moderation = Moderation()
|
||||
results = await moderation.amoderation(content=content)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(content)
|
||||
|
||||
results = await moderation.amoderation_with_categories(content=content)
|
||||
assert isinstance(results, list)
|
||||
assert results
|
||||
for m in results:
|
||||
assert "flagged" in m
|
||||
assert "true_categories" in m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
30
tests/metagpt/tools/test_openai_text_to_embedding.py
Normal file
30
tests/metagpt/tools/test_openai_text_to_embedding.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_openai_text_to_embedding.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding():
|
||||
# Prerequisites
|
||||
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
|
||||
assert not CONFIG.OPENAI_API_TYPE
|
||||
assert CONFIG.OPENAI_API_MODEL
|
||||
|
||||
result = await oas3_openai_text_to_embedding("Panda emoji")
|
||||
assert result
|
||||
assert result.model
|
||||
assert len(result.data) > 0
|
||||
assert len(result.data[0].embedding) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
38
tests/metagpt/tools/test_openai_text_to_image.py
Normal file
38
tests/metagpt/tools/test_openai_text_to_image.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_openai_text_to_image.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.openai_text_to_image import (
|
||||
OpenAIText2Image,
|
||||
oas3_openai_text_to_image,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draw():
|
||||
# Prerequisites
|
||||
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
|
||||
assert not CONFIG.OPENAI_API_TYPE
|
||||
assert CONFIG.OPENAI_API_MODEL
|
||||
|
||||
binary_data = await oas3_openai_text_to_image("Panda emoji")
|
||||
assert binary_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_image():
|
||||
data = await OpenAIText2Image.get_image_data(
|
||||
url="https://www.baidu.com/img/PCtm_d9c8750bed0b3c7d089fa7d55720d6cf.png"
|
||||
)
|
||||
assert data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/2 17:46
|
||||
@Author : alexanderwu
|
||||
@File : test_prompt_generator.py
|
||||
@File : test_prompt_writer.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -9,6 +9,7 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
|
@ -44,6 +45,15 @@ async def test_search_engine(
|
|||
max_results,
|
||||
as_string,
|
||||
):
|
||||
# Prerequisites
|
||||
if search_engine_typpe is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert CONFIG.SERPAPI_API_KEY and CONFIG.SERPAPI_API_KEY != "YOUR_API_KEY"
|
||||
elif search_engine_typpe is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert CONFIG.GOOGLE_API_KEY and CONFIG.GOOGLE_API_KEY != "YOUR_API_KEY"
|
||||
assert CONFIG.GOOGLE_CSE_ID and CONFIG.GOOGLE_CSE_ID != "YOUR_CSE_ID"
|
||||
elif search_engine_typpe is SearchEngineType.SERPER_GOOGLE:
|
||||
assert CONFIG.SERPER_API_KEY and CONFIG.SERPER_API_KEY != "YOUR_API_KEY"
|
||||
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
|
|
@ -52,3 +62,7 @@ async def test_search_engine(
|
|||
else:
|
||||
assert isinstance(rsp, list)
|
||||
assert len(rsp) == max_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk"
|
|||
|
||||
@pytest.fixture()
|
||||
def search_engine_server():
|
||||
# Prerequisites
|
||||
# https://www.meilisearch.com/docs/learn/getting_started/installation
|
||||
# brew update && brew install meilisearch
|
||||
|
||||
meilisearch_process = subprocess.Popen(["meilisearch", "--master-key", f"{MASTER_KEY}"], stdout=subprocess.PIPE)
|
||||
time.sleep(3)
|
||||
yield
|
||||
|
|
@ -26,6 +30,10 @@ def search_engine_server():
|
|||
|
||||
|
||||
def test_meilisearch(search_engine_server):
|
||||
# Prerequisites
|
||||
# https://www.meilisearch.com/docs/learn/getting_started/installation
|
||||
# brew update && brew install meilisearch
|
||||
|
||||
search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY)
|
||||
|
||||
# 假设有一个名为"books"的数据源,包含要添加的文档库
|
||||
|
|
@ -44,3 +52,7 @@ def test_meilisearch(search_engine_server):
|
|||
# 添加文档库到搜索引擎
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
logger.info(search_engine.search("Book 1"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -3,8 +3,11 @@
|
|||
"""
|
||||
@Time : 2023/4/30 21:44
|
||||
@Author : alexanderwu
|
||||
@File : test_ut_generator.py
|
||||
@File : test_ut_writer.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import API_QUESTIONS_PATH, SWAGGER_PATH, UT_PY_PATH
|
||||
from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator
|
||||
|
|
@ -12,7 +15,10 @@ from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator
|
|||
|
||||
class TestUTWriter:
|
||||
def test_api_to_ut_sample(self):
|
||||
# Prerequisites
|
||||
swagger_file = SWAGGER_PATH / "yft_swaggerApi.json"
|
||||
assert swagger_file.exists()
|
||||
|
||||
tags = ["测试"] # "智能合同导入", "律师审查", "ai合同审查", "草拟合同&律师在线审查", "合同审批", "履约管理", "签约公司"]
|
||||
# 这里在文件中手动加入了两个测试标签的API
|
||||
|
||||
|
|
@ -25,3 +31,12 @@ class TestUTWriter:
|
|||
ret = utg.generate_ut(include_tags=tags)
|
||||
# 后续加入对文件生成内容与数量的检验
|
||||
assert ret
|
||||
|
||||
pathname = Path(__file__).with_suffix(".tmp")
|
||||
utg.ask_gpt_and_save(question="question", tag="tag", fname=str(pathname))
|
||||
assert pathname.exists()
|
||||
pathname.unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -4,8 +4,8 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -18,14 +18,17 @@ from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
|||
ids=["playwright", "selenium"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, url, urls):
|
||||
conf = Config()
|
||||
browser = web_browser_engine.WebBrowserEngine(options=conf.runtime_options, engine=browser_type)
|
||||
browser = web_browser_engine.WebBrowserEngine(engine=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, str)
|
||||
assert "深度赋智" in result
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("深度赋智" in i) for i in results)
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import web_browser_engine_playwright
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -19,25 +20,25 @@ from metagpt.tools import web_browser_engine_playwright
|
|||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
|
||||
conf = Config()
|
||||
global_proxy = conf.global_proxy
|
||||
global_proxy = CONFIG.global_proxy
|
||||
try:
|
||||
if use_proxy:
|
||||
conf.global_proxy = proxy
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(
|
||||
options=conf.runtime_options, browser_type=browser_type, **kwagrs
|
||||
)
|
||||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "DeepWisdom" in result
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("DeepWisdom" in i) for i in results)
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
conf.global_proxy = global_proxy
|
||||
CONFIG.global_proxy = global_proxy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import web_browser_engine_selenium
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -19,23 +20,28 @@ from metagpt.tools import web_browser_engine_selenium
|
|||
ids=["chrome-normal", "firefox-normal", "edge-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
|
||||
conf = Config()
|
||||
global_proxy = conf.global_proxy
|
||||
# Prerequisites
|
||||
# firefox, chrome, Microsoft Edge
|
||||
|
||||
global_proxy = CONFIG.global_proxy
|
||||
try:
|
||||
if use_proxy:
|
||||
conf.global_proxy = proxy
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(options=conf.runtime_options, browser_type=browser_type)
|
||||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "Deepwisdom" in result
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("Deepwisdom" in i.inner_text) for i in results)
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
conf.global_proxy = global_proxy
|
||||
CONFIG.global_proxy = global_proxy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Any, Set
|
||||
|
||||
import pytest
|
||||
|
|
@ -17,7 +18,7 @@ from metagpt.actions import RunCode
|
|||
from metagpt.const import get_metagpt_root
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, check_cmd_exists
|
||||
|
||||
|
||||
class TestGetProjectRoot:
|
||||
|
|
@ -28,13 +29,12 @@ class TestGetProjectRoot:
|
|||
|
||||
def test_get_project_root(self):
|
||||
project_root = get_metagpt_root()
|
||||
assert project_root.name == "metagpt"
|
||||
assert project_root.name == "MetaGPT"
|
||||
|
||||
def test_get_root_exception(self):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
self.change_etc_dir()
|
||||
get_metagpt_root()
|
||||
assert str(exc_info.value) == "Project root not found."
|
||||
self.change_etc_dir()
|
||||
project_root = get_metagpt_root()
|
||||
assert project_root
|
||||
|
||||
def test_any_to_str(self):
|
||||
class Input(BaseModel):
|
||||
|
|
@ -65,8 +65,8 @@ class TestGetProjectRoot:
|
|||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
),
|
||||
Input(
|
||||
x={TutorialAssistant, RunCode(), "a"},
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
x={TutorialAssistant, "a"},
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "a"},
|
||||
),
|
||||
Input(
|
||||
x=(TutorialAssistant, RunCode(), "a"),
|
||||
|
|
@ -77,6 +77,25 @@ class TestGetProjectRoot:
|
|||
v = any_to_str_set(i.x)
|
||||
assert v == i.want
|
||||
|
||||
def test_check_cmd_exists(self):
|
||||
class Input(BaseModel):
|
||||
command: str
|
||||
platform: str
|
||||
|
||||
inputs = [
|
||||
{"command": "cat", "platform": "linux"},
|
||||
{"command": "ls", "platform": "linux"},
|
||||
{"command": "mspaint", "platform": "windows"},
|
||||
]
|
||||
plat = "windows" if platform.system().lower() == "windows" else "linux"
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
result = check_cmd_exists(seed.command)
|
||||
if plat == seed.platform:
|
||||
assert result == 0
|
||||
else:
|
||||
assert result != 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue