mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
update provider unittests
This commit is contained in:
parent
44eec631ea
commit
454e6164fb
18 changed files with 460 additions and 78 deletions
|
|
@ -7,13 +7,13 @@
|
|||
"""
|
||||
|
||||
import anthropic
|
||||
from anthropic import Anthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class Claude2:
|
||||
def ask(self, prompt):
|
||||
def ask(self, prompt: str) -> str:
|
||||
client = Anthropic(api_key=CONFIG.anthropic_api_key)
|
||||
|
||||
res = client.completions.create(
|
||||
|
|
@ -23,10 +23,10 @@ class Claude2:
|
|||
)
|
||||
return res.completion
|
||||
|
||||
async def aask(self, prompt):
|
||||
client = Anthropic(api_key=CONFIG.anthropic_api_key)
|
||||
async def aask(self, prompt: str) -> str:
|
||||
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
|
||||
|
||||
res = client.completions.create(
|
||||
res = await aclient.completions.create(
|
||||
model="claude-2",
|
||||
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
|
||||
max_tokens_to_sample=1000,
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
return "\n".join([f"{i.role}: {i.content}" for i in messages])
|
||||
|
||||
def messages_to_dict(self, messages):
|
||||
"""objects to [{"role": "user", "content": msg}] etc."""
|
||||
|
|
|
|||
|
|
@ -133,7 +133,9 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -79,6 +79,9 @@ class GeminiGPTAPI(BaseGPTAPI):
|
|||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
@ -133,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -57,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
|
||||
self.model = config.ollama_api_model
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
|
@ -144,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
class SparkAPI(BaseGPTAPI):
|
||||
class SparkGPTAPI(BaseGPTAPI):
|
||||
def __init__(self):
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
return rsp
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
else:
|
||||
|
|
@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI):
|
|||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
# 不支持
|
||||
logger.error("该功能禁用。")
|
||||
w = GetMessageFromWeb(messages)
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
|
|
@ -131,6 +134,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages, timeout=timeout)
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
29
tests/metagpt/provider/test_anthropic_api.py
Normal file
29
tests/metagpt/provider/test_anthropic_api.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
||||
|
||||
def mock_llm_ask(self, msg: str) -> str:
|
||||
return resp
|
||||
|
||||
|
||||
async def mock_llm_aask(self, msg: str) -> str:
|
||||
return resp
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask)
|
||||
assert resp == Claude2().ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask)
|
||||
assert resp == await Claude2().aask(prompt)
|
||||
|
|
@ -6,10 +6,106 @@
|
|||
@File : test_base_gpt_api.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
def test_message():
|
||||
message = Message(role="user", content="wtf")
|
||||
|
||||
class MockBaseGPTAPI(BaseGPTAPI):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
|
||||
|
||||
def test_base_gpt_api():
|
||||
message = Message(role="user", content="hello")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
msg_prompt = base_gpt_api.messages_to_prompt([message])
|
||||
assert msg_prompt == "user: hello"
|
||||
|
||||
msg_dict = base_gpt_api.messages_to_dict([message])
|
||||
assert msg_dict == [{"role": "user", "content": "hello"}]
|
||||
|
||||
openai_funccall_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "test",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
func: dict = base_gpt_api.get_choice_function(openai_funccall_resp)
|
||||
assert func == {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
}
|
||||
|
||||
func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp)
|
||||
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
|
||||
|
||||
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
resp = base_gpt_api.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_gpt_api():
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
|
||||
resp = await base_gpt_api.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -10,41 +10,82 @@ from openai.types.chat.chat_completion import (
|
|||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireWorksGPTAPI,
|
||||
)
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks"))
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content))
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion:
|
||||
def test_fireworks_costmanager():
|
||||
cost_manager = FireworksCostManager()
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return default_resp.choices[0].message.content
|
||||
|
||||
|
||||
def test_fireworks_completion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = FireWorksGPTAPI().completion(messages)
|
||||
assert "fireworks" in resp.choices[0].message.content
|
||||
resp = fireworks_gpt.completion(messages)
|
||||
assert resp.choices[0].message.content == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion:
|
||||
return default_resp
|
||||
resp = fireworks_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = await FireWorksGPTAPI().acompletion(messages, stream=False)
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
assert "fireworks" in resp.choices[0].message.content
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
20
tests/metagpt/provider/test_general_api_requestor.py
Normal file
20
tests/metagpt/provider/test_general_api_requestor.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of APIRequestor
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
|
||||
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def test_api_requestor():
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_requestor():
|
||||
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
|
@ -9,33 +9,62 @@ import pytest
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
|
||||
messages = [{"role": "user", "parts": "who are you"}]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockGeminiResponse(ABC):
|
||||
text: str
|
||||
|
||||
|
||||
default_resp = MockGeminiResponse(text="I'm gemini from google")
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(
|
||||
self, messgaes: list[dict], stream: bool = False, timeout: int = 60
|
||||
) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
|
||||
resp = GeminiGPTAPI().completion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp.text == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
resp = gemini_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await GeminiGPTAPI().acompletion(messages)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
38
tests/metagpt/provider/test_human_provider.py
Normal file
38
tests/metagpt/provider/test_human_provider.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of HumanProvider
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
|
||||
|
||||
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
||||
return mock_llm_ask(msg)
|
||||
|
||||
|
||||
def test_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
assert resp_content == human_provider.ask(None)
|
||||
|
||||
assert not human_provider.completion(messages=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
|
|
@ -5,11 +5,11 @@
|
|||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTLLMAPI()
|
||||
llm = MetaGPTAPI()
|
||||
assert llm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,30 +4,58 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
CONFIG.ollama_api_base = "http://xxx"
|
||||
|
||||
|
||||
default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)
|
||||
resp = OllamaGPTAPI().completion(messages)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
resp = ollama_gpt.completion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> dict:
|
||||
return default_resp
|
||||
resp = ollama_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await OllamaGPTAPI().acompletion(messages)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -85,14 +85,23 @@ def test_ask_code_list_str():
|
|||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other")
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
|
@ -101,7 +110,9 @@ class TestOpenAI:
|
|||
def config_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
|
@ -110,8 +121,10 @@ class TestOpenAI:
|
|||
def config_azure_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
|
@ -129,8 +142,8 @@ class TestOpenAI:
|
|||
instance = OpenAIGPTAPI()
|
||||
instance.config = config_azure
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
|
||||
assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
assert "http_client" not in async_kwargs
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,51 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkGPTAPI
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_message():
|
||||
llm = SparkAPI()
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
|
||||
logger.info(llm.ask('只回答"收到了"这三个字。'))
|
||||
result = llm.ask("写一篇五百字的日记")
|
||||
logger.info(result)
|
||||
assert len(result) > 100
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_spark_completion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = spark_gpt.completion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = spark_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = await spark_gpt.acompletion([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -4,34 +4,62 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}}
|
||||
CONFIG.zhipuai_api_key = "xxx"
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = ZhiPuAIGPTAPI().completion(messages)
|
||||
resp = zhipu_gpt.completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict:
|
||||
return default_resp
|
||||
resp = zhipu_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue