Revert "Feat add qianfan api support"

This commit is contained in:
better629 2024-02-08 07:32:34 +08:00 committed by GitHub
parent 351b3ae8df
commit 5a2084cda8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 319 additions and 574 deletions

View file

@ -54,6 +54,7 @@ jobs:
export ALLOW_OPENAI_API_CALL=0
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -31,7 +31,7 @@ jobs:
- name: Test with pytest
run: |
export ALLOW_OPENAI_API_CALL=0
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -13,18 +13,7 @@ from metagpt.logs import logger
async def main():
llm = LLM()
# llm type check
id_ques = "what's your name"
logger.info(f"{id_ques}: ")
logger.info(await llm.aask(id_ques))
logger.info("\n\n")
logger.info(
await llm.aask(
"who are you", system_msgs=["act as a robot, answer 'I'am robot' if the question is 'who are you'"]
)
)
logger.info(await llm.aask("hello world"))
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]

View file

@ -24,7 +24,6 @@ class LLMType(Enum):
METAGPT = "metagpt"
AZURE = "azure"
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
def __missing__(self, key):
return self.OPENAI
@ -37,18 +36,13 @@ class LLMConfig(YamlModel):
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str = "sk-"
api_key: str
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For Cloud Service Provider like Baidu/ Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None
api_secret: Optional[str] = None

View file

@ -16,7 +16,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
__all__ = [
"FireworksLLM",
@ -29,5 +28,4 @@ __all__ = [
"OllamaLLM",
"HumanProvider",
"SparkLLM",
"QianFanLLM",
]

View file

@ -11,12 +11,11 @@ from abc import ABC, abstractmethod
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
class BaseLLM(ABC):
@ -68,28 +67,6 @@ class BaseLLM(ABC):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
"""update each request's token cost
Args:
model (str): model name or in some scenarios called endpoint
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
"""
calc_usage = self.config.calc_usage and local_calc_usage
model = model if model else self.model
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
if calc_usage and self.cost_manager:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
except Exception as e:
logger.error(f"{self.__class__.__name__} updats costs failed! exp: {e}")
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
async def aask(
self,
msg: str,

View file

@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
@ -81,6 +81,17 @@ class FireworksLLM(OpenAILLM):
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
@ -102,7 +113,7 @@ class FireworksLLM(OpenAILLM):
usage = CompletionUsage(**chunk.usage)
full_content = "".join(collected_content)
self._update_costs(usage.model_dump())
self._update_costs(usage)
return full_content
@retry(

View file

@ -72,6 +72,16 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -46,6 +46,16 @@ class OllamaLLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})

View file

@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import TokenCostManager
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@ -34,3 +34,14 @@ class OpenLLM(OpenAILLM):
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -29,7 +29,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -55,13 +55,16 @@ class OpenAILLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -237,6 +240,16 @@ class OpenAILLM(BaseLLM):
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token

View file

@ -1,152 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
import copy
import os
import qianfan
from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import (
QianFan_EndPoint_TOKEN_COSTS,
QianFan_MODEL_TOKEN_COSTS,
)
@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
    """LLM provider for Baidu QianFan, covering ERNIE (wen xin yi yan) and hosted open-source models.

    Refs
        Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
        Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
        Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
            https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
    """

    def __init__(self, config: LLMConfig):
        self.config = config
        self.use_system_prompt = False  # only some ERNIE-x related models support system_prompt
        self.__init_qianfan()
        # cost manager is seeded with QianFan-specific per-model/per-endpoint prices
        self.cost_manager = CostManager(token_costs=self.token_costs)

    def __init_qianfan(self):
        """Validate auth config, detect system-prompt support, and build the qianfan client.

        Raises:
            ValueError: if neither (access_key, secret_key) nor (api_key, secret_key) is set.
        """
        if self.config.access_key and self.config.secret_key:
            # for system level auth, use access_key and secret_key, recommended by official
            # set environment variable due to official recommendation
            os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
            os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
        elif self.config.api_key and self.config.secret_key:
            # for application level auth, use api_key and secret_key
            # set environment variable due to official recommendation
            os.environ.setdefault("QIANFAN_AK", self.config.api_key)
            os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
        else:
            raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")

        support_system_pairs = [
            ("ERNIE-Bot-4", "completions_pro"),  # (model, corresponding-endpoint)
            ("ERNIE-Bot-8k", "ernie_bot_8k"),
            ("ERNIE-Bot", "completions"),
            ("ERNIE-Bot-turbo", "eb-instant"),
            ("ERNIE-Speed", "ernie_speed"),
            ("EB-turbo-AppBuilder", "ai_apaas"),
        ]
        if self.config.model in [pair[0] for pair in support_system_pairs]:
            # only some ERNIE models support
            self.use_system_prompt = True
        # endpoint names map 1:1 to the models above, so the same rule applies
        if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
            self.use_system_prompt = True

        # `model` and `endpoint` are mutually exclusive ways to address a deployment
        assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
        assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"

        # merge model- and endpoint-keyed price tables into one lookup
        self.token_costs = copy.deepcopy(QianFan_MODEL_TOKEN_COSTS)
        self.token_costs.update(QianFan_EndPoint_TOKEN_COSTS)

        # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
        self.calc_usage = self.config.calc_usage and self.config.endpoint is None
        self.aclient: ChatCompletion = qianfan.ChatCompletion()

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Assemble the keyword arguments for a qianfan chat request."""
        kwargs = {
            "messages": messages,
            "stream": stream,
        }
        if self.config.temperature > 0:
            # different model has default temperature. only set when it's specified.
            kwargs["temperature"] = self.config.temperature
        if self.config.endpoint:
            kwargs["endpoint"] = self.config.endpoint
        elif self.config.model:
            kwargs["model"] = self.config.model

        if self.use_system_prompt:
            # if the model support system prompt, extract and pass it
            if messages[0]["role"] == "system":
                kwargs["messages"] = messages[1:]
                kwargs["system"] = messages[0]["content"]  # set system prompt here

        return kwargs

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        model_or_endpoint = self.config.model if self.config.model else self.config.endpoint
        # only compute usage locally when a price entry exists for the model/endpoint
        local_calc_usage = True if model_or_endpoint in self.token_costs else False
        super()._update_costs(usage, model_or_endpoint, local_calc_usage)

    def get_choice_text(self, resp: JsonBody) -> str:
        """Extract the generated text from a qianfan response body."""
        return resp.get("result", "")

    def completion(self, messages: list[dict]) -> JsonBody:
        """Synchronous chat completion; returns the raw response body."""
        resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def _achat_completion(self, messages: list[dict]) -> JsonBody:
        """Async chat completion; returns the raw response body."""
        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
        # NOTE(review): `timeout` is accepted for interface parity but not forwarded
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Stream a chat completion, logging chunks as they arrive; returns the joined text."""
        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
        collected_content = []
        usage = {}
        async for chunk in resp:
            content = chunk.body.get("result", "")
            # usage is re-read on each chunk; the last chunk's value is the one charged
            usage = chunk.body.get("usage", {})
            log_llm_stream(content)
            collected_content.append(content)
        log_llm_stream("\n")

        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
        """Return the completion text, streaming if requested; retries on ConnectionError."""
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)

View file

@ -53,6 +53,16 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()

View file

@ -29,7 +29,6 @@ class CostManager(BaseModel):
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -47,8 +46,7 @@ class CostManager(BaseModel):
return
cost = (
prompt_tokens * self.token_costs[model]["prompt"]
+ completion_tokens * self.token_costs[model]["completion"]
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(

View file

@ -38,59 +38,6 @@ TOKEN_COSTS = {
}
"""
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
"""
QianFan_MODEL_TOKEN_COSTS = {
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
"ERNIE-Bot": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}
QianFan_EndPoint_TOKEN_COSTS = {
"completions_pro": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
"ernie_bot_8k": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
"completions": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot"],
"eb-instant": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
"ai_apaas": QianFan_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
"ernie_speed": QianFan_MODEL_TOKEN_COSTS["ERNIE-Speed"],
"bloomz_7b1": QianFan_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
"llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
"llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
"llama_2_70b": QianFan_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
"chatglm2_6b_32k": QianFan_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
"aquilachat_7b": QianFan_MODEL_TOKEN_COSTS["AquilaChat-7B"],
"mixtral_8x7b_instruct": QianFan_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
"sqlcoder_7b": QianFan_MODEL_TOKEN_COSTS["SQLCoder-7B"],
"codellama_7b_instruct": QianFan_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
"xuanyuan_70b_chat": QianFan_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
"qianfan_bloomz_7b_compressed": QianFan_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
"qianfan_chinese_llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
"qianfan_chinese_llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
"chatlaw": QianFan_MODEL_TOKEN_COSTS["ChatLaw"],
"yi_34b_chat": QianFan_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,

View file

@ -67,4 +67,3 @@ playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
anytree
ipywidgets==8.1.1
Pillow
qianfan==0.3.1

View file

@ -42,15 +42,3 @@ mock_llm_config_zhipu = LLMConfig(
model="mock_zhipu_model",
proxy="http://localhost:8080",
)
mock_llm_config_spark = LLMConfig(
api_type="spark",
app_id="xxx",
api_key="xxx",
api_secret="xxx",
domain="generalv2",
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")

View file

@ -1,117 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from qianfan.resources.typing import QfResponse
from metagpt.provider.base_llm import BaseLLM
prompt = "who are you?"
messages = [{"role": "user", "content": prompt}]

resp_cont_tmpl = "I'm {name}"
default_resp_cont = resp_cont_tmpl.format(name="GPT")


# part of whole ChatCompletion of openai like structure
def get_part_chat_completion(name: str) -> dict:
    """Return a plain-dict fragment shaped like an OpenAI ChatCompletion reply."""
    assistant_message = {
        "role": "assistant",
        "content": resp_cont_tmpl.format(name=name),
    }
    only_choice = {"index": 0, "message": assistant_message, "finish_reason": "stop"}
    return {
        "choices": [only_choice],
        "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
    }
def get_openai_chat_completion(name: str) -> ChatCompletion:
    """Build a full mock ChatCompletion whose assistant reply greets as *name*."""
    reply = ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name))
    sole_choice = Choice(finish_reason="stop", index=0, message=reply, logprobs=None)
    return ChatCompletion(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model="xx/xxx",
        object="chat.completion",
        created=1703300855,
        choices=[sole_choice],
        usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
    )
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
    """Build a mock streaming ChatCompletionChunk; usage is a dict when *usage_as_dict* is set."""
    usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
    if usage_as_dict:
        # some providers hand usage back as a plain mapping rather than a model
        usage = usage.model_dump()
    delta = ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name))
    return ChatCompletionChunk(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model="xx/xxx",
        object="chat.completion.chunk",
        created=1703300855,
        choices=[AChoice(delta=delta, finish_reason="stop", index=0, logprobs=None)],
        usage=usage,
    )
# For gemini
gemini_messages = [{"role": "user", "parts": prompt}]
# For QianFan
qf_jsonbody_dict = {
    "id": "as-4v1h587fyv",
    "object": "chat.completion",
    "created": 1695021339,
    "result": "",
    "is_truncated": False,
    "need_clear_history": False,
    "usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
}


def get_qianfan_response(name: str) -> QfResponse:
    """Return a mock QianFan QfResponse whose `result` greets as *name*.

    Fix: work on a shallow copy of ``qf_jsonbody_dict`` instead of writing
    ``result`` back into the shared module-level template — the original
    mutated the template, leaking the last caller's value into every later
    use of ``qf_jsonbody_dict``.
    """
    body = dict(qf_jsonbody_dict)
    body["result"] = resp_cont_tmpl.format(name=name)
    return QfResponse(code=200, body=body)
# For llm general chat functions call
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
    """Exercise a provider's common chat entry points against a mocked backend.

    Calls ``aask`` with and without ``stream=False`` and ``acompletion_text``
    with ``stream`` both off and on, asserting each returns *resp_cont*.
    """
    resp = await llm.aask(prompt, stream=False)
    assert resp == resp_cont
    # default-argument form of aask must yield the same content
    resp = await llm.aask(prompt)
    assert resp == resp_cont
    resp = await llm.acompletion_text(messages, stream=False)
    assert resp == resp_cont
    resp = await llm.acompletion_text(messages, stream=True)
    assert resp == resp_cont

View file

@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion
from metagpt.provider.anthropic_api import Claude2
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
resp_cont = resp_cont_tmpl.format(name="Claude")
prompt = "who are you"
resp = "I'am Claude2"
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
assert resp == Claude2(mock_llm_config).ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)
assert resp == await Claude2(mock_llm_config).aask(prompt)

View file

@ -11,13 +11,21 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
prompt,
)
name = "GPT"
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
class MockBaseLLM(BaseLLM):
@ -25,13 +33,16 @@ class MockBaseLLM(BaseLLM):
pass
def completion(self, messages: list[dict], timeout=3):
return get_part_chat_completion(name)
return default_chat_resp
async def acompletion(self, messages: list[dict], timeout=3):
return get_part_chat_completion(name)
return default_chat_resp
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return default_resp_cont
return resp_content
async def close(self):
return default_chat_resp
def test_base_llm():
@ -75,25 +86,25 @@ def test_base_llm():
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_llm.ask(prompt)
# assert resp == default_resp_cont
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_llm.ask_batch([prompt])
# assert resp == default_resp_cont
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_code([prompt])
# assert resp == default_resp_cont
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
@pytest.mark.asyncio
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_llm.aask(prompt)
assert resp == default_resp_cont
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_llm.aask_batch([prompt])
assert resp == default_resp_cont
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
# resp = await base_llm.aask_code([prompt])
# assert resp == default_resp_cont
# resp = await base_llm.aask_code([prompt_msg])
# assert resp == resp_content

View file

@ -3,7 +3,14 @@
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
@ -13,19 +20,42 @@ from metagpt.provider.fireworks_api import (
)
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
name = "fireworks"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
def test_fireworks_costmanager():
@ -58,17 +88,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_fireworks_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
fireworks_llm = FireworksLLM(mock_llm_config)
fireworks_llm.model = "llama-v2-13b-chat"
fireworks_gpt = FireworksLLM(mock_llm_config)
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_llm._update_costs(
fireworks_gpt._update_costs(
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
)
assert fireworks_llm.get_costs() == Costs(
assert fireworks_gpt.get_costs() == Costs(
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
)
resp = await fireworks_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await fireworks_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -11,12 +11,6 @@ from google.generativeai.types import content_types
from metagpt.provider.google_gemini_api import GeminiLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
gemini_messages,
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
@dataclass
@ -24,8 +18,10 @@ class MockGeminiResponse(ABC):
text: str
resp_cont = resp_cont_tmpl.format(name="gemini")
default_resp = MockGeminiResponse(text=resp_cont)
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
@ -64,18 +60,28 @@ async def test_gemini_acompletion(mocker):
mock_gemini_generate_content_async,
)
gemini_llm = GeminiLLM(mock_llm_config)
gemini_gpt = GeminiLLM(mock_llm_config)
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
usage = gemini_gpt.get_usage(messages, resp_content)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_llm.completion(gemini_messages)
resp = gemini_gpt.completion(messages)
assert resp == default_resp
resp = await gemini_llm.acompletion(gemini_messages)
resp = await gemini_gpt.acompletion(messages)
assert resp.text == default_resp.text
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -9,15 +9,12 @@ import pytest
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="ollama")
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
@ -44,12 +41,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_llm = OllamaLLM(mock_llm_config)
ollama_gpt = OllamaLLM(mock_llm_config)
resp = await ollama_llm.acompletion(messages)
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -3,26 +3,53 @@
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
)
name = "llama2-7b"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
)
default_resp_chunk = get_openai_chat_completion_chunk(name)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
@ -41,16 +68,25 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_llm = OpenLLM(mock_llm_config)
openllm_llm.model = "llama-v2-13b-chat"
openllm_gpt = OpenLLM(mock_llm_config)
openllm_gpt.model = "llama-v2-13b-chat"
openllm_llm.cost_manager = CostManager()
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_llm.get_costs() == Costs(
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
)
resp = await openllm_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await openllm_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)
resp = await openllm_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -1,56 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
from typing import AsyncIterator, Union
import pytest
from qianfan.resources.typing import JsonBody, QfResponse
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
from tests.metagpt.provider.req_resp_const import (
get_qianfan_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "ERNIE-Bot-turbo"
resp_cont = resp_cont_tmpl.format(name=name)
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
return get_qianfan_response(name=name)
async def mock_qianfan_ado(
self, messages: list[dict], model: str, stream: bool = True, system: str = None
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
resps = [get_qianfan_response(name=name)]
if stream:
async def aresp_iterator(resps: list[JsonBody]):
for resp in resps:
yield resp
return aresp_iterator(resps)
else:
return resps[0]
@pytest.mark.asyncio
async def test_qianfan_acompletion(mocker):
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
resp = qianfan_llm.completion(messages)
assert resp.get("result") == resp_cont
resp = await qianfan_llm.acompletion(messages)
assert resp.get("result") == resp_cont
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)

View file

@ -4,18 +4,12 @@
import pytest
from metagpt.config2 import Config
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_cont = resp_cont_tmpl.format(name="Spark")
prompt_msg = "who are you"
resp_content = "I'm Spark"
class MockWebSocketApp(object):
@ -29,7 +23,7 @@ class MockWebSocketApp(object):
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
@ -37,26 +31,34 @@ def test_get_msg_from_web(mocker):
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_cont
return resp_content
@pytest.mark.asyncio
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
async def test_spark_aask():
llm = SparkLLM(Config.from_home("spark.yaml").llm)
resp = await llm.aask("Hello!")
assert resp == resp_cont
print(resp)
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_llm = SparkLLM(mock_llm_config)
spark_gpt = SparkLLM(mock_llm_config)
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_gpt.acompletion([])
assert resp == resp_content
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content

View file

@ -6,24 +6,22 @@ import pytest
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
from tests.metagpt.provider.req_resp_const import (
get_part_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_part_chat_completion(name)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
async def mock_zhipuai_acreate_stream(self, **kwargs):
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
async def __aiter__(self):
for event in self.events:
@ -48,12 +46,22 @@ async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
resp = await zhipu_llm.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_cont
resp = await zhipu_gpt.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_content
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content
def test_zhipuai_proxy():

7
tests/spark.yaml Normal file
View file

@ -0,0 +1,7 @@
llm:
api_type: "spark"
app_id: "xxx"
api_key: "xxx"
api_secret: "xxx"
domain: "generalv2"
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"