update provider unittests

This commit is contained in:
better629 2023-12-25 18:00:51 +08:00
parent 44eec631ea
commit 454e6164fb
18 changed files with 460 additions and 78 deletions

View file

@ -7,13 +7,13 @@
"""
import anthropic
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
from metagpt.config import CONFIG
class Claude2:
def ask(self, prompt):
def ask(self, prompt: str) -> str:
client = Anthropic(api_key=CONFIG.anthropic_api_key)
res = client.completions.create(
@ -23,10 +23,10 @@ class Claude2:
)
return res.completion
async def aask(self, prompt):
client = Anthropic(api_key=CONFIG.anthropic_api_key)
async def aask(self, prompt: str) -> str:
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
res = client.completions.create(
res = await aclient.completions.create(
model="claude-2",
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
max_tokens_to_sample=1000,

View file

@ -162,7 +162,7 @@ class BaseGPTAPI(BaseChatbot):
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{i.role}: {i.content}" for i in messages])
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""

View file

@ -133,7 +133,9 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -79,6 +79,9 @@ class GeminiGPTAPI(BaseGPTAPI):
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text
@ -133,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -57,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI):
self.model = config.ollama_api_model
def close(self):
pass
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
@ -144,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
class SparkAPI(BaseGPTAPI):
class SparkGPTAPI(BaseGPTAPI):
def __init__(self):
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def close(self):
pass
def ask(self, msg: str) -> str:
message = [self._default_system_msg(), self._user_msg(msg)]
rsp = self.completion(message)
return rsp
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)

View file

@ -64,6 +64,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
@ -131,6 +134,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages, timeout=timeout)
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of Claude2
import pytest
from metagpt.provider.anthropic_api import Claude2
prompt = "who are you"
resp = "I'am Claude2"
def mock_llm_ask(self, msg: str) -> str:
return resp
async def mock_llm_aask(self, msg: str) -> str:
return resp
def test_claude2_ask(mocker):
mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask)
assert resp == Claude2().ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask)
assert resp == await Claude2().aask(prompt)

View file

@ -6,10 +6,106 @@
@File : test_base_gpt_api.py
"""
import pytest
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Message
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
def test_message():
message = Message(role="user", content="wtf")
class MockBaseGPTAPI(BaseGPTAPI):
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
return resp_content
async def close(self):
return default_chat_resp
def test_base_gpt_api():
message = Message(role="user", content="hello")
assert "role" in message.to_dict()
assert "user" in str(message)
base_gpt_api = MockBaseGPTAPI()
msg_prompt = base_gpt_api.messages_to_prompt([message])
assert msg_prompt == "user: hello"
msg_dict = base_gpt_api.messages_to_dict([message])
assert msg_dict == [{"role": "user", "content": "hello"}]
openai_funccall_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "test",
"tool_calls": [
{
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
"type": "function",
"function": {
"name": "execute",
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
},
}
],
},
"finish_reason": "stop",
}
]
}
func: dict = base_gpt_api.get_choice_function(openai_funccall_resp)
assert func == {
"name": "execute",
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
}
func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp)
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
resp = base_gpt_api.ask(prompt_msg)
assert resp == resp_content
resp = base_gpt_api.ask_batch([prompt_msg])
assert resp == resp_content
resp = base_gpt_api.ask_code([prompt_msg])
assert resp == resp_content
@pytest.mark.asyncio
async def test_async_base_gpt_api():
base_gpt_api = MockBaseGPTAPI()
resp = await base_gpt_api.aask(prompt_msg)
assert resp == resp_content
resp = await base_gpt_api.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_gpt_api.aask_code([prompt_msg])
assert resp == resp_content

View file

@ -10,41 +10,82 @@ from openai.types.chat.chat_completion import (
)
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import FireWorksGPTAPI
from metagpt.provider.fireworks_api import (
MODEL_GRADE_TOKEN_COSTS,
FireworksCostManager,
FireWorksGPTAPI,
)
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks"))
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content))
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
messages = [{"role": "user", "content": "who are you"}]
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion:
def test_fireworks_costmanager():
cost_manager = FireworksCostManager()
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion:
return default_resp
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion:
return default_resp
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return default_resp.choices[0].message.content
def test_fireworks_completion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask)
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
fireworks_gpt = FireWorksGPTAPI()
resp = FireWorksGPTAPI().completion(messages)
assert "fireworks" in resp.choices[0].message.content
resp = fireworks_gpt.completion(messages)
assert resp.choices[0].message.content == resp_content
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion:
return default_resp
resp = fireworks_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask)
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
)
fireworks_gpt = FireWorksGPTAPI()
resp = await FireWorksGPTAPI().acompletion(messages, stream=False)
resp = await fireworks_gpt.acompletion(messages, stream=False)
assert resp.choices[0].message.content in resp_content
assert "fireworks" in resp.choices[0].message.content
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of APIRequestor
import pytest
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
def test_api_requestor():
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
assert b"baidu" in resp
@pytest.mark.asyncio
async def test_async_api_requestor():
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
assert b"baidu" in resp

View file

@ -9,33 +9,62 @@ import pytest
from metagpt.provider.google_gemini_api import GeminiGPTAPI
messages = [{"role": "user", "parts": "who are you"}]
@dataclass
class MockGeminiResponse(ABC):
text: str
default_resp = MockGeminiResponse(text="I'm gemini from google")
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse:
return default_resp
async def mock_llm_acompletion(
self, messgaes: list[dict], stream: bool = False, timeout: int = 60
) -> MockGeminiResponse:
return default_resp
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
resp = GeminiGPTAPI().completion(messages)
assert resp.text == default_resp.text
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
gemini_gpt = GeminiGPTAPI()
resp = gemini_gpt.completion(messages)
assert resp.text == resp_content
async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse:
return default_resp
resp = gemini_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
resp = await GeminiGPTAPI().acompletion(messages)
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
)
gemini_gpt = GeminiGPTAPI()
resp = await gemini_gpt.acompletion(messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of HumanProvider
import pytest
from metagpt.provider.human_provider import HumanProvider
resp_content = "test"
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
return resp_content
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
return mock_llm_ask(msg)
def test_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
human_provider = HumanProvider()
assert resp_content == human_provider.ask(None)
assert not human_provider.completion(messages=[])
@pytest.mark.asyncio
async def test_async_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
human_provider = HumanProvider()
resp = await human_provider.aask(None)
assert resp_content == resp
resp = await human_provider.acompletion([])
assert not resp

View file

@ -5,11 +5,11 @@
@Author : mashenquan
@File : test_metagpt_llm_api.py
"""
from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI
from metagpt.provider.metagpt_api import MetaGPTAPI
def test_metagpt():
llm = MetaGPTLLMAPI()
llm = MetaGPTAPI()
assert llm

View file

@ -4,30 +4,58 @@
import pytest
from metagpt.config import CONFIG
from metagpt.provider.ollama_api import OllamaGPTAPI
messages = [{"role": "user", "content": "who are you"}]
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
CONFIG.ollama_api_base = "http://xxx"
default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}}
def mock_llm_ask(self, messages: list[dict]) -> dict:
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
return default_resp
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
return default_resp
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)
resp = OllamaGPTAPI().completion(messages)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
ollama_gpt = OllamaGPTAPI()
resp = ollama_gpt.completion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
async def mock_llm_aask(self, messgaes: list[dict]) -> dict:
return default_resp
resp = ollama_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)
resp = await OllamaGPTAPI().acompletion(messages)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
ollama_gpt = OllamaGPTAPI()
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -85,14 +85,23 @@ def test_ask_code_list_str():
class TestOpenAI:
@pytest.fixture
def config(self):
return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other")
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy=None,
openai_api_type="other",
)
@pytest.fixture
def config_azure(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_api_version="test_version",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy=None,
openai_api_type="azure",
)
@ -101,7 +110,9 @@ class TestOpenAI:
def config_proxy(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy="http://proxy.com",
openai_api_type="other",
)
@ -110,8 +121,10 @@ class TestOpenAI:
def config_azure_proxy(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_api_version="test_version",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy="http://proxy.com",
openai_api_type="azure",
)
@ -129,8 +142,8 @@ class TestOpenAI:
instance = OpenAIGPTAPI()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs

View file

@ -1,11 +1,51 @@
from metagpt.logs import logger
from metagpt.provider.spark_api import SparkAPI
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
import pytest
from metagpt.provider.spark_api import SparkGPTAPI
prompt_msg = "who are you"
resp_content = "I'm Spark"
def test_message():
llm = SparkAPI()
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str:
return resp_content
logger.info(llm.ask('只回答"收到了"这三个字。'))
result = llm.ask("写一篇五百字的日记")
logger.info(result)
assert len(result) > 100
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str:
return resp_content
def test_spark_completion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
spark_gpt = SparkGPTAPI()
resp = spark_gpt.completion([])
assert resp == resp_content
resp = spark_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
spark_gpt = SparkGPTAPI()
resp = await spark_gpt.acompletion([], stream=False)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -4,34 +4,62 @@
import pytest
from metagpt.config import CONFIG
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}}
CONFIG.zhipuai_api_key = "xxx"
messages = [{"role": "user", "content": "who are you"}]
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}}
def mock_llm_ask(self, messages: list[dict]) -> dict:
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
return default_resp
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
return default_resp
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_zhipuai_completion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
zhipu_gpt = ZhiPuAIGPTAPI()
resp = ZhiPuAIGPTAPI().completion(messages)
resp = zhipu_gpt.completion(messages)
assert resp["code"] == 200
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
assert resp["data"]["choices"][0]["content"] == resp_content
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict:
return default_resp
resp = zhipu_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
)
zhipu_gpt = ZhiPuAIGPTAPI()
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
resp = await zhipu_gpt.acompletion(messages)
assert resp["data"]["choices"][0]["content"] == resp_content
assert resp["code"] == 200
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content