update provider unittests to update coverage rate

This commit is contained in:
better629 2023-12-29 02:39:00 +08:00
parent 5fc8207950
commit 0f047e5693
26 changed files with 509 additions and 76 deletions

View file

@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from metagpt.utils.common import OutputParser, general_after_log
TAG = "CONTENT"
@ -275,7 +275,7 @@ class ActionNode:
output_class = self.create_model_class(output_class_name, output_data_mapping)
if schema == "json":
parsed_data = llm_output_postprecess(
parsed_data = llm_output_postprocess(
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
)
else: # using markdown parser

View file

@ -100,7 +100,7 @@ def log_info(message, **params):
def log_warn(message, **params):
msg = logfmt(dict(message=message, **params))
print(msg, file=sys.stderr)
logger.warn(msg)
logger.warning(msg)
def logfmt(props):

View file

@ -79,9 +79,6 @@ class GeminiLLM(BaseLLM):
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -31,7 +31,6 @@ class OpenLLMCostManager(CostManager):
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OPEN_LLM)

View file

@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import (
)
class BasePostPrecessPlugin(object):
model = None # the plugin of the `model`, use to judge in `llm_postprecess`
class BasePostProcessPlugin(object):
model = None # the plugin of the `model`, use to judge in `llm_postprocess`
def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
"""

View file

@ -4,17 +4,17 @@
from typing import Union
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
def llm_output_postprecess(
def llm_output_postprocess(
output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None
) -> Union[dict, str]:
"""
default use BasePostPrecessPlugin if there is not matched plugin.
default use BasePostProcessPlugin if there is not matched plugin.
"""
# TODO choose different model's plugin according to the model_name
postprecess_plugin = BasePostPrecessPlugin()
postprocess_plugin = BasePostProcessPlugin()
result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key)
result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key)
return result

View file

@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI):
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
return f"{arr[0]}/api", f"/{arr[1]}"
@classmethod

View file

@ -68,9 +68,6 @@ class ZhiPuAILLM(BaseLLM):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
raw_output = """
[CONTENT]
{
"Original Requirements": "xxx"
}
[/CONTENT]
"""
raw_schema = {
"title":"prd",
"type":"object",
"properties":{
"Original Requirements":{
"title":"Original Requirements",
"type":"string"
},
},
"required":[
"Original Requirements",
]
}
def test_llm_post_process_plugin():
post_process_plugin = BasePostProcessPlugin()
output = post_process_plugin.run(
output=raw_output,
schema=raw_schema
)
assert "Original Requirements" in output

View file

@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import raw_output, raw_schema
def test_llm_output_postprocess():
output = llm_output_postprocess(output=raw_output, schema=raw_schema)
assert "Original Requirements" in output

View file

@ -2,28 +2,33 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of Claude2
import pytest
import pytest
from anthropic.resources.completions import Completion
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2
CONFIG.anthropic_api_key = "xxx"
prompt = "who are you"
resp = "I'am Claude2"
def mock_llm_ask(self, msg: str) -> str:
return resp
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_llm_aask(self, msg: str) -> str:
return resp
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask)
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2().ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask)
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2().aask(prompt)

View file

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.config import CONFIG
CONFIG.OPENAI_API_VERSION = "xx"
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
def test_azure_openai_api():
_ = AzureOpenAILLM()

View file

@ -8,6 +8,9 @@ from openai.types.chat.chat_completion import (
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
@ -16,8 +19,11 @@ from metagpt.provider.fireworks_api import (
FireworksCostManager,
FireworksLLM,
)
from metagpt.utils.cost_manager import Costs
CONFIG.fireworks_api_key = "xxx"
CONFIG.max_budget = 10
CONFIG.calc_usage = True
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
@ -36,6 +42,22 @@ default_resp = ChatCompletion(
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
@ -50,29 +72,37 @@ def test_fireworks_costmanager():
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion:
return default_resp
cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
assert cost_manager.total_cost == 0.5
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion:
return default_resp
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
if stream:
class Iterator(object):
async def __aiter__(self):
yield default_resp_chunk
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return default_resp.choices[0].message.content
return Iterator()
else:
return default_resp
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream
)
fireworks_gpt = FireworksLLM()
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
resp = await fireworks_gpt.acompletion(messages, stream=False)
fireworks_gpt = FireworksLLM()
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_gpt._update_costs(
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
)
assert fireworks_gpt.get_costs() == Costs(
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
)
resp = await fireworks_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await fireworks_gpt.aask(prompt_msg, stream=False)

View file

@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
import os
import requests
import aiohttp
from typing import Iterator, Tuple, Union, Generator, AsyncGenerator
from openai import OpenAIError
from metagpt.provider.general_api_base import ApiType, log_debug, log_info, log_warn, OpenAIResponse, \
_requests_proxies_arg, _aiohttp_proxies_arg, _make_session, parse_stream_helper, parse_stream, APIRequestor
def test_basic():
_ = ApiType.from_str("azure")
_ = ApiType.from_str("azuread")
_ = ApiType.from_str("openai")
with pytest.raises(OpenAIError):
_ = ApiType.from_str("xx")
os.environ.setdefault("LLM_LOG", "debug")
log_debug("debug")
log_warn("warn")
log_info("info")
def test_openai_response():
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
assert resp.request_id is None
assert resp.retry_after == 3
assert resp.operation_location is None
assert resp.organization is None
assert resp.response_ms is None
def test_proxy():
assert _requests_proxies_arg(proxy=None) is None
proxy = "127.0.0.1:80"
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
proxy_dict = {"http": proxy}
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
proxy_dict = {"https": proxy}
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
assert _make_session() is not None
def test_parse_stream():
assert parse_stream_helper(None) is None
assert parse_stream_helper(b"data: [DONE]") is None
assert parse_stream_helper(b"data: test") == "test"
assert parse_stream_helper(b"test") is None
for line in parse_stream([b"data: test"]):
assert line == "test"
api_requestor = APIRequestor(base_url="http://www.baidu.com")
def mock_interpret_response(self, result: requests.Response, stream: bool
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
return b"baidu", False
async def mock_interpret_async_response(self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
return b"baidu", True
def test_api_requestor(mocker):
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu")
@pytest.mark.asyncio
async def test_async_api_requestor(mocker):
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response)
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu")

View file

@ -4,11 +4,24 @@
import pytest
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.general_api_requestor import (
GeneralAPIRequestor,
parse_stream,
parse_stream_helper,
)
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
def test_parse_stream():
assert parse_stream_helper(None) is None
assert parse_stream_helper(b"data: [DONE]") is None
assert parse_stream_helper(b"data: test") == b"test"
assert parse_stream_helper(b"test") is None
for line in parse_stream([b"data: test"]):
assert line == b"test"
def test_api_requestor():
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
assert b"baidu" in resp

View file

@ -6,9 +6,14 @@ from abc import ABC
from dataclasses import dataclass
import pytest
from google.ai import generativelanguage as glm
from google.generativeai.types import content_types
from metagpt.config import CONFIG
from metagpt.provider.google_gemini_api import GeminiLLM
CONFIG.gemini_api_key = "xx"
@dataclass
class MockGeminiResponse(ABC):
@ -21,29 +26,53 @@ resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse:
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
return glm.CountTokensResponse(total_tokens=20)
async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
return glm.CountTokensResponse(total_tokens=20)
def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse:
return default_resp
async def mock_llm_acompletion(
self, messgaes: list[dict], stream: bool = False, timeout: int = 60
) -> MockGeminiResponse:
return default_resp
async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse:
if stream:
class Iterator(object):
async def __aiter__(self):
yield default_resp
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
return Iterator()
else:
return default_resp
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens)
mocker.patch(
"metagpt.provider.google_gemini_api.GeminiLLM._achat_completion_stream", mock_llm_achat_completion_stream
"metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async
)
mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content)
mocker.patch(
"google.generativeai.generative_models.GenerativeModel.generate_content_async",
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM()
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
usage = gemini_gpt.get_usage(messages, resp_content)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(messages)
assert resp.text == default_resp.text

View file

@ -2,6 +2,9 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of ollama api
import json
from typing import Any, Tuple
import pytest
from metagpt.config import CONFIG
@ -14,25 +17,33 @@ resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
CONFIG.ollama_api_base = "http://xxx"
CONFIG.max_budget = 10
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
return default_resp
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
if stream:
class Iterator(object):
events = [
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
]
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
return default_resp
async def __aiter__(self):
for event in self.events:
yield event
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
return Iterator(), None, None
else:
raw_default_resp = default_resp.copy()
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
return json.dumps(raw_default_resp).encode(), None, None
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion_stream", mock_llm_achat_completion_stream)
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM()
resp = await ollama_gpt.acompletion(messages)

View file

@ -0,0 +1,93 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, Choice as AChoice
from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
CONFIG.max_budget = 10
CONFIG.calc_usage = True
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
]
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(
content=resp_content,
role="assistant"
),
finish_reason="stop",
index=0,
logprobs=None,
)
]
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
async def mock_openai_acompletions_create(self, stream: bool=False, **kwargs) -> ChatCompletionChunk:
if stream:
class Iterator(object):
async def __aiter__(self):
yield default_resp_chunk
return Iterator()
else:
return default_resp
@pytest.mark.asyncio
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_gpt = OpenLLM()
openllm_gpt.model = "llama-v2-13b-chat"
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0)
resp = await openllm_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await openllm_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -2,9 +2,14 @@ from unittest.mock import Mock
import pytest
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import UserMessage
CONFIG.openai_proxy = None
print("openai_api_key ", CONFIG.openai_api_key)
@pytest.mark.asyncio
async def test_aask_code():

View file

@ -4,24 +4,31 @@
import pytest
from metagpt.provider.spark_api import SparkLLM
from metagpt.config import CONFIG
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
CONFIG.spark_appid = "xxx"
CONFIG.spark_api_secret = "xxx"
CONFIG.spark_api_key = "xxx"
CONFIG.domain = "xxxxxx"
CONFIG.spark_url = "xxxx"
prompt_msg = "who are you"
resp_content = "I'm Spark"
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str:
return resp_content
def test_get_msg_from_web():
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str:
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion_text", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM()
resp = await spark_gpt.acompletion([])

View file

@ -3,36 +3,68 @@
# @Desc : the unittest of ZhiPuAILLM
import pytest
from zhipuai.utils.sse_client import Event
from metagpt.config import CONFIG
from metagpt.provider.zhipuai_api import ZhiPuAILLM
CONFIG.zhipuai_api_key = "xxx"
CONFIG.zhipuai_api_key = "xxx.xxx"
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}}
default_resp = {
"code": 200,
"data": {
"choices": [{"role": "assistant", "content": resp_content}],
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
},
}
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
def mock_zhipuai_invoke(**kwargs) -> dict:
return default_resp
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
return default_resp
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
async def mock_zhipuai_asse_invoke(**kwargs):
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [
Event(id="xxx", event="add", data=resp_content, retry=0),
Event(
id="xxx",
event="finish",
data="",
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
),
]
async def __aiter__(self):
for event in self.events:
yield event
async for chunk in Iterator():
yield chunk
async def async_events(self):
async for chunk in self._aread():
yield chunk
return MockResponse()
@pytest.mark.asyncio
async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
zhipu_gpt = ZhiPuAILLM()
resp = await zhipu_gpt.acompletion(messages)
@ -51,7 +83,7 @@ async def test_zhipuai_acompletion(mocker):
assert resp == resp_content
def test_zhipuai_proxy(mocker):
def test_zhipuai_proxy():
import openai
from metagpt.config import CONFIG

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
@pytest.mark.asyncio
async def test_async_sse_client():
class Iterator(object):
async def __aiter__(self):
yield b"data: test_value"
async_sse_client = AsyncSSEClient(event_source=Iterator())
async for event in async_sse_client.async_events():
assert event.data, "test_value"

View file

@ -0,0 +1,40 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from typing import Any, Tuple
import pytest
import zhipuai
from zhipuai.model_api.api import InvokeType
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
api_key = "xxx.xxx"
zhipuai.api_key = api_key
default_resp = {"result": "test response"}
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
return default_resp, None, None
@pytest.mark.asyncio
async def test_zhipu_model_api(mocker):
header = ZhiPuModelAPI.get_header()
zhipuai_default_headers.update({"Authorization": api_key})
assert header == zhipuai_default_headers
sse_header = ZhiPuModelAPI.get_sse_header()
assert len(sse_header["Authorization"]) == 191
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
assert url_prefix == "https://open.bigmodel.cn/api"
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
result = await ZhiPuModelAPI.arequest(InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"})
assert result == default_resp