mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
commit
9940022096
13 changed files with 447 additions and 17 deletions
|
|
@ -7,7 +7,7 @@
|
|||
## Or, you can configure OPENAI_PROXY to access official OPENAI_API_BASE.
|
||||
OPENAI_API_BASE: "https://api.openai.com/v1"
|
||||
#OPENAI_PROXY: "http://127.0.0.1:8118"
|
||||
#OPENAI_API_KEY: "YOUR_API_KEY"
|
||||
#OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model
|
||||
OPENAI_API_MODEL: "gpt-4"
|
||||
MAX_TOKENS: 1500
|
||||
RPM: 10
|
||||
|
|
@ -31,6 +31,9 @@ RPM: 10
|
|||
#DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME"
|
||||
#DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID"
|
||||
|
||||
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
|
||||
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### for Search
|
||||
|
||||
## Supported values: serpapi/google/serper/ddg
|
||||
|
|
|
|||
|
|
@ -45,10 +45,11 @@ class Config(metaclass=Singleton):
|
|||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = self._get("Anthropic_API_KEY")
|
||||
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and (
|
||||
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
|
||||
):
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first")
|
||||
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
|
||||
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
|
||||
(not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
|
||||
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key):
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
|
||||
self.openai_api_base = self._get("OPENAI_API_BASE")
|
||||
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
if openai_proxy:
|
||||
|
|
|
|||
|
|
@ -6,15 +6,27 @@
|
|||
@File : llm.py
|
||||
"""
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.anthropic_api import Claude2 as Claude
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
DEFAULT_LLM = LLM()
|
||||
CLAUDE_LLM = Claude()
|
||||
|
||||
async def ai_func(prompt):
|
||||
"""使用LLM进行QA
|
||||
QA with LLMs
|
||||
"""
|
||||
return await DEFAULT_LLM.aask(prompt)
|
||||
def LLM() -> "BaseGPTAPI":
|
||||
""" initialize different LLM instance according to the key field existence"""
|
||||
# TODO a little trick, can use registry to initialize LLM instance further
|
||||
if CONFIG.openai_api_key and CONFIG.openai_api_key.startswith("sk-"):
|
||||
llm = OpenAIGPTAPI()
|
||||
elif CONFIG.claude_api_key:
|
||||
llm = Claude()
|
||||
elif CONFIG.spark_api_key:
|
||||
llm = SparkAPI()
|
||||
elif CONFIG.zhipuai_api_key and CONFIG.zhipuai_api_key != "YOUR_API_KEY":
|
||||
llm = ZhiPuAIGPTAPI()
|
||||
else:
|
||||
raise RuntimeError("You should config a LLM configuration first")
|
||||
|
||||
return llm
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
|||
class BaseChatbot(ABC):
|
||||
"""Abstract GPT class"""
|
||||
mode: str = "API"
|
||||
use_system_prompt: bool = True
|
||||
|
||||
@abstractmethod
|
||||
def ask(self, msg: str) -> str:
|
||||
|
|
|
|||
|
|
@ -32,15 +32,17 @@ class BaseGPTAPI(BaseChatbot):
|
|||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)] if self.use_system_prompt \
|
||||
else [self._user_msg(msg)]
|
||||
else:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt \
|
||||
else [self._user_msg(msg)]
|
||||
rsp = await self.acompletion_text(message, stream=True)
|
||||
logger.debug(message)
|
||||
# logger.debug(rsp)
|
||||
|
|
|
|||
62
metagpt/provider/general_api_requestor.py
Normal file
62
metagpt/provider/general_api_requestor.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : General Async API for http-based LLM model
|
||||
|
||||
from typing import AsyncGenerator, Tuple, Union, Optional, Literal
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
from openai.api_requestor import APIRequestor
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class GeneralAPIRequestor(APIRequestor):
|
||||
"""
|
||||
usage
|
||||
# full_url = "{api_base}{url}"
|
||||
requester = GeneralAPIRequestor(api_base=api_base)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
params=kwargs,
|
||||
request_timeout=120
|
||||
)
|
||||
"""
|
||||
|
||||
def _interpret_response_line(
|
||||
self, rbody: str, rcode: int, rheaders, stream: bool
|
||||
) -> str:
|
||||
# just do nothing to meet the APIRequestor process and return the raw data
|
||||
# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
|
||||
|
||||
return rbody
|
||||
|
||||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
|
||||
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
|
||||
return (
|
||||
self._interpret_response_line(
|
||||
line, result.status, result.headers, stream=True
|
||||
)
|
||||
async for line in result.content
|
||||
), True
|
||||
else:
|
||||
try:
|
||||
await result.read()
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
|
||||
raise TimeoutError("Request timed out") from e
|
||||
except aiohttp.ClientError as exp:
|
||||
logger.warning(f"response: {result.content}, exp: {exp}")
|
||||
return (
|
||||
self._interpret_response_line(
|
||||
await result.read(), # let the caller to decode the msg
|
||||
result.status,
|
||||
result.headers,
|
||||
stream=False,
|
||||
),
|
||||
False,
|
||||
)
|
||||
3
metagpt/provider/zhipuai/__init__.py
Normal file
3
metagpt/provider/zhipuai/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
78
metagpt/provider/zhipuai/async_sse_client.py
Normal file
78
metagpt/provider/zhipuai/async_sse_client.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : async_sse_client to make keep the use of Event to access response
|
||||
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
|
||||
|
||||
from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR
|
||||
|
||||
|
||||
class AsyncSSEClient(SSEClient):
|
||||
|
||||
async def _aread(self):
|
||||
data = b""
|
||||
async for chunk in self._event_source:
|
||||
for line in chunk.splitlines(True):
|
||||
data += line
|
||||
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
||||
yield data
|
||||
data = b""
|
||||
if data:
|
||||
yield data
|
||||
|
||||
async def async_events(self):
|
||||
async for chunk in self._aread():
|
||||
event = Event()
|
||||
# Split before decoding so splitlines() only uses \r and \n
|
||||
for line in chunk.splitlines():
|
||||
# Decode the line.
|
||||
line = line.decode(self._char_enc)
|
||||
|
||||
# Lines starting with a separator are comments and are to be
|
||||
# ignored.
|
||||
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
|
||||
continue
|
||||
|
||||
data = line.split(_FIELD_SEPARATOR, 1)
|
||||
field = data[0]
|
||||
|
||||
# Ignore unknown fields.
|
||||
if field not in event.__dict__:
|
||||
self._logger.debug(
|
||||
"Saw invalid field %s while parsing " "Server Side Event", field
|
||||
)
|
||||
continue
|
||||
|
||||
if len(data) > 1:
|
||||
# From the spec:
|
||||
# "If value starts with a single U+0020 SPACE character,
|
||||
# remove it from value."
|
||||
if data[1].startswith(" "):
|
||||
value = data[1][1:]
|
||||
else:
|
||||
value = data[1]
|
||||
else:
|
||||
# If no value is present after the separator,
|
||||
# assume an empty value.
|
||||
value = ""
|
||||
|
||||
# The data field may come over multiple lines and their values
|
||||
# are concatenated with each other.
|
||||
if field == "data":
|
||||
event.__dict__[field] += value + "\n"
|
||||
else:
|
||||
event.__dict__[field] = value
|
||||
|
||||
# Events with no data are not dispatched.
|
||||
if not event.data:
|
||||
continue
|
||||
|
||||
# If the data field ends with a newline, remove it.
|
||||
if event.data.endswith("\n"):
|
||||
event.data = event.data[0:-1]
|
||||
|
||||
# Empty event names default to 'message'
|
||||
event.event = event.event or "message"
|
||||
|
||||
# Dispatch the event
|
||||
self._logger.debug("Dispatching %s...", event)
|
||||
yield event
|
||||
79
metagpt/provider/zhipuai/zhipu_model_api.py
Normal file
79
metagpt/provider/zhipuai/zhipu_model_api.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : zhipu model api to support sync & async for invoke & sse_invoke
|
||||
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import ModelAPI, InvokeType
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
|
||||
|
||||
class ZhiPuModelAPI(ModelAPI):
|
||||
|
||||
@classmethod
|
||||
def get_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
zhipuai_default_headers.update({"Authorization": token})
|
||||
return zhipuai_default_headers
|
||||
|
||||
@classmethod
|
||||
def get_sse_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
headers = {
|
||||
"Authorization": token
|
||||
}
|
||||
return headers
|
||||
|
||||
@classmethod
|
||||
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
|
||||
# use this method to prevent zhipu api upgrading to different version.
|
||||
# and follow the GeneralAPIRequestor implemented based on openai sdk
|
||||
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
|
||||
"""
|
||||
example:
|
||||
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
|
||||
"""
|
||||
arr = zhipu_api_url.split("/api/")
|
||||
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
return f"{arr[0]}/api", f"/{arr[1]}"
|
||||
|
||||
@classmethod
|
||||
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
|
||||
# TODO to make the async request to be more generic for models in http mode.
|
||||
assert method in ["post", "get"]
|
||||
|
||||
api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs)
|
||||
requester = GeneralAPIRequestor(api_base=api_base)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
params=kwargs,
|
||||
request_timeout=zhipuai.api_timeout_seconds
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def ainvoke(cls, **kwargs) -> dict:
|
||||
""" async invoke different from raw method `async_invoke` which get the final result by task_id"""
|
||||
headers = cls.get_header()
|
||||
resp = await cls.arequest(invoke_type=InvokeType.SYNC,
|
||||
stream=False,
|
||||
method="post",
|
||||
headers=headers,
|
||||
kwargs=kwargs)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
|
||||
""" async sse_invoke """
|
||||
headers = cls.get_sse_header()
|
||||
return AsyncSSEClient(await cls.arequest(invoke_type=InvokeType.SSE,
|
||||
stream=True,
|
||||
method="post",
|
||||
headers=headers,
|
||||
kwargs=kwargs))
|
||||
137
metagpt/provider/zhipuai_api.py
Normal file
137
metagpt/provider/zhipuai_api.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
|
||||
|
||||
from enum import Enum
|
||||
import json
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_fixed,
|
||||
)
|
||||
from requests import ConnectionError
|
||||
|
||||
import zhipuai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
|
||||
class ZhiPuEvent(Enum):
|
||||
ADD = "add"
|
||||
ERROR = "error"
|
||||
INTERRUPTED = "interrupted"
|
||||
FINISH = "finish"
|
||||
|
||||
|
||||
class ZhiPuAIGPTAPI(BaseGPTAPI):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
"""
|
||||
|
||||
use_system_prompt: bool = False # zhipuai has no system prompt when use api
|
||||
|
||||
def __init__(self):
|
||||
self.__init_zhipuai(CONFIG)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self._cost_manager = CostManager()
|
||||
|
||||
def __init_zhipuai(self, config: CONFIG):
|
||||
assert config.zhipuai_api_key
|
||||
zhipuai.api_key = config.zhipuai_api_key
|
||||
|
||||
def _const_kwargs(self, messages: list[dict]) -> dict:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"prompt": messages,
|
||||
"temperature": 0.3
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
""" update each request's token cost """
|
||||
if CONFIG.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("zhipuai updats costs failed!", e)
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
""" get the first text of choice from llm response """
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
assert assist_msg["role"] == "assistant"
|
||||
return assist_msg.get("content")
|
||||
|
||||
def completion(self, messages: list[dict]) -> dict:
|
||||
resp = self.llm.invoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for event in response.async_events():
|
||||
if event.event == ZhiPuEvent.ADD.value:
|
||||
content = event.data
|
||||
collected_content.append(content)
|
||||
print(content, end="")
|
||||
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
|
||||
content = event.data
|
||||
logger.error(f"event error: {content}", end="")
|
||||
collected_content.append([content])
|
||||
elif event.event == ZhiPuEvent.FINISH.value:
|
||||
"""
|
||||
event.meta
|
||||
{
|
||||
"task_status":"SUCCESS",
|
||||
"usage":{
|
||||
"completion_tokens":351,
|
||||
"prompt_tokens":595,
|
||||
"total_tokens":946
|
||||
},
|
||||
"task_id":"xx",
|
||||
"request_id":"xxx"
|
||||
}
|
||||
"""
|
||||
meta = json.loads(event.meta)
|
||||
usage = meta.get("usage")
|
||||
else:
|
||||
print(f"zhipuapi else event: {event.data}", end="")
|
||||
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(1),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
""" response in async with stream or non-stream mode """
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
@ -22,6 +22,7 @@ TOKEN_COSTS = {
|
|||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ TOKEN_MAX = {
|
|||
"gpt-4-32k-0314": 32768,
|
||||
"gpt-4-0613": 8192,
|
||||
"text-embedding-ada-002": 8192,
|
||||
"chatglm_turbo": 32768
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -68,7 +70,9 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
|
||||
f"num_tokens_from_messages() is not implemented for model {model}. "
|
||||
f"See https://github.com/openai/openai-python/blob/main/chatml.md "
|
||||
f"for information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
|
|
|
|||
|
|
@ -44,3 +44,4 @@ ta==0.10.2
|
|||
semantic-kernel==0.3.13.dev0
|
||||
wrapt==1.15.0
|
||||
websocket-client==0.58.0
|
||||
zhipuai==1.0.7
|
||||
|
|
|
|||
47
tests/metagpt/provider/test_zhipuai_api.py
Normal file
47
tests/metagpt/provider/test_zhipuai_api.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ZhiPuAIGPTAPI
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [
|
||||
{"role": "assistant", "content": "I'm chatglm-turbo"}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "who are you"}
|
||||
]
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
|
||||
|
||||
resp = ZhiPuAIGPTAPI().completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
|
||||
|
||||
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
|
||||
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue