mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
Revert "Feat add qianfan api support"
This commit is contained in:
parent
351b3ae8df
commit
5a2084cda8
28 changed files with 319 additions and 574 deletions
1
.github/workflows/fulltest.yaml
vendored
1
.github/workflows/fulltest.yaml
vendored
|
|
@ -54,6 +54,7 @@ jobs:
|
|||
export ALLOW_OPENAI_API_CALL=0
|
||||
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
|
||||
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
|
||||
echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
|
||||
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
|
||||
- name: Show coverage report
|
||||
run: |
|
||||
|
|
|
|||
2
.github/workflows/unittest.yaml
vendored
2
.github/workflows/unittest.yaml
vendored
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
- name: Test with pytest
|
||||
run: |
|
||||
export ALLOW_OPENAI_API_CALL=0
|
||||
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
|
||||
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
|
||||
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
|
||||
- name: Show coverage report
|
||||
run: |
|
||||
|
|
|
|||
|
|
@ -13,18 +13,7 @@ from metagpt.logs import logger
|
|||
|
||||
async def main():
|
||||
llm = LLM()
|
||||
# llm type check
|
||||
id_ques = "what's your name"
|
||||
logger.info(f"{id_ques}: ")
|
||||
logger.info(await llm.aask(id_ques))
|
||||
logger.info("\n\n")
|
||||
|
||||
logger.info(
|
||||
await llm.aask(
|
||||
"who are you", system_msgs=["act as a robot, answer 'I'am robot' if the question is 'who are you'"]
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(await llm.aask("hello world"))
|
||||
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
|
||||
|
||||
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ class LLMType(Enum):
|
|||
METAGPT = "metagpt"
|
||||
AZURE = "azure"
|
||||
OLLAMA = "ollama"
|
||||
QIANFAN = "qianfan" # Baidu BCE
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
@ -37,18 +36,13 @@ class LLMConfig(YamlModel):
|
|||
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
|
||||
"""
|
||||
|
||||
api_key: str = "sk-"
|
||||
api_key: str
|
||||
api_type: LLMType = LLMType.OPENAI
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
|
||||
# For Cloud Service Provider like Baidu/ Alibaba
|
||||
access_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
endpoint: Optional[str] = None # for self-deployed model on the cloud
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
app_id: Optional[str] = None
|
||||
api_secret: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
|||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
|
|
@ -29,5 +28,4 @@ __all__ = [
|
|||
"OllamaLLM",
|
||||
"HumanProvider",
|
||||
"SparkLLM",
|
||||
"QianFanLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,12 +11,11 @@ from abc import ABC, abstractmethod
|
|||
from typing import Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
|
|
@ -68,28 +67,6 @@ class BaseLLM(ABC):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
|
||||
"""update each request's token cost
|
||||
Args:
|
||||
model (str): model name or in some scenarios called endpoint
|
||||
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
|
||||
"""
|
||||
calc_usage = self.config.calc_usage and local_calc_usage
|
||||
model = model if model else self.model
|
||||
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
|
||||
if calc_usage and self.cost_manager:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.__class__.__name__} updats costs failed! exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
|
|
@ -81,6 +81,17 @@ class FireworksLLM(OpenAILLM):
|
|||
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not context.cost_manager
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
|
|
@ -102,7 +113,7 @@ class FireworksLLM(OpenAILLM):
|
|||
usage = CompletionUsage(**chunk.usage)
|
||||
|
||||
full_content = "".join(collected_content)
|
||||
self._update_costs(usage.model_dump())
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
|
|
|
|||
|
|
@ -72,6 +72,16 @@ class GeminiLLM(BaseLLM):
|
|||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,16 @@ class OllamaLLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"ollama updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the resp content from llm response"""
|
||||
assist_msg = resp.get("message", {})
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
from metagpt.utils.cost_manager import Costs, TokenCostManager
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
|
|
@ -34,3 +34,14 @@ class OpenLLM(OpenAILLM):
|
|||
logger.error(f"usage calculation failed!: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use OpenLLMCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
|||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -55,13 +55,16 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
kwargs = self._make_client_kwargs()
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
|
|
@ -237,6 +240,16 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
return usage
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return self.config.max_token
|
||||
|
|
|
|||
|
|
@ -1,152 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
|
||||
import copy
|
||||
import os
|
||||
|
||||
import qianfan
|
||||
from qianfan import ChatCompletion
|
||||
from qianfan.resources.typing import JsonBody
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import (
|
||||
QianFan_EndPoint_TOKEN_COSTS,
|
||||
QianFan_MODEL_TOKEN_COSTS,
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMType.QIANFAN)
|
||||
class QianFanLLM(BaseLLM):
|
||||
"""
|
||||
Refs
|
||||
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
|
||||
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
|
||||
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.use_system_prompt = False # only some ERNIE-x related models support system_prompt
|
||||
self.__init_qianfan()
|
||||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_qianfan(self):
|
||||
if self.config.access_key and self.config.secret_key:
|
||||
# for system level auth, use access_key and secret_key, recommended by official
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
|
||||
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
|
||||
elif self.config.api_key and self.config.secret_key:
|
||||
# for application level auth, use api_key and secret_key
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
|
||||
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
|
||||
else:
|
||||
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
|
||||
|
||||
support_system_pairs = [
|
||||
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
|
||||
("ERNIE-Bot-8k", "ernie_bot_8k"),
|
||||
("ERNIE-Bot", "completions"),
|
||||
("ERNIE-Bot-turbo", "eb-instant"),
|
||||
("ERNIE-Speed", "ernie_speed"),
|
||||
("EB-turbo-AppBuilder", "ai_apaas"),
|
||||
]
|
||||
if self.config.model in [pair[0] for pair in support_system_pairs]:
|
||||
# only some ERNIE models support
|
||||
self.use_system_prompt = True
|
||||
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
|
||||
self.use_system_prompt = True
|
||||
|
||||
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
|
||||
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
|
||||
|
||||
self.token_costs = copy.deepcopy(QianFan_MODEL_TOKEN_COSTS)
|
||||
self.token_costs.update(QianFan_EndPoint_TOKEN_COSTS)
|
||||
|
||||
# self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
|
||||
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
|
||||
self.aclient: ChatCompletion = qianfan.ChatCompletion()
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
if self.config.temperature > 0:
|
||||
# different model has default temperature. only set when it's specified.
|
||||
kwargs["temperature"] = self.config.temperature
|
||||
if self.config.endpoint:
|
||||
kwargs["endpoint"] = self.config.endpoint
|
||||
elif self.config.model:
|
||||
kwargs["model"] = self.config.model
|
||||
|
||||
if self.use_system_prompt:
|
||||
# if the model support system prompt, extract and pass it
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
model_or_endpoint = self.config.model if self.config.model else self.config.endpoint
|
||||
local_calc_usage = True if model_or_endpoint in self.token_costs else False
|
||||
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
|
||||
|
||||
def get_choice_text(self, resp: JsonBody) -> str:
|
||||
return resp.get("result", "")
|
||||
|
||||
def completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
content = chunk.body.get("result", "")
|
||||
usage = chunk.body.get("usage", {})
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
@ -53,6 +53,16 @@ class ZhiPuAILLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ class CostManager(BaseModel):
|
|||
total_budget: float = 0
|
||||
max_budget: float = 10.0
|
||||
total_cost: float = 0
|
||||
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
|
|
@ -47,8 +46,7 @@ class CostManager(BaseModel):
|
|||
return
|
||||
|
||||
cost = (
|
||||
prompt_tokens * self.token_costs[model]["prompt"]
|
||||
+ completion_tokens * self.token_costs[model]["completion"]
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -38,59 +38,6 @@ TOKEN_COSTS = {
|
|||
}
|
||||
|
||||
|
||||
"""
|
||||
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
|
||||
"""
|
||||
QianFan_MODEL_TOKEN_COSTS = {
|
||||
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
|
||||
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
|
||||
"ERNIE-Bot": {"prompt": 0.017, "completion": 0.017},
|
||||
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
|
||||
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
QianFan_EndPoint_TOKEN_COSTS = {
|
||||
"completions_pro": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
|
||||
"ernie_bot_8k": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
|
||||
"completions": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot"],
|
||||
"eb-instant": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
|
||||
"ai_apaas": QianFan_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
|
||||
"ernie_speed": QianFan_MODEL_TOKEN_COSTS["ERNIE-Speed"],
|
||||
"bloomz_7b1": QianFan_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
|
||||
"llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
|
||||
"llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
|
||||
"llama_2_70b": QianFan_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
|
||||
"chatglm2_6b_32k": QianFan_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
|
||||
"aquilachat_7b": QianFan_MODEL_TOKEN_COSTS["AquilaChat-7B"],
|
||||
"mixtral_8x7b_instruct": QianFan_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
|
||||
"sqlcoder_7b": QianFan_MODEL_TOKEN_COSTS["SQLCoder-7B"],
|
||||
"codellama_7b_instruct": QianFan_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
|
||||
"xuanyuan_70b_chat": QianFan_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
|
||||
"qianfan_bloomz_7b_compressed": QianFan_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
|
||||
"qianfan_chinese_llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
|
||||
"qianfan_chinese_llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
|
||||
"chatlaw": QianFan_MODEL_TOKEN_COSTS["ChatLaw"],
|
||||
"yi_34b_chat": QianFan_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
|
||||
}
|
||||
|
||||
|
||||
TOKEN_MAX = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
|
|
|
|||
|
|
@ -67,4 +67,3 @@ playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
|
|||
anytree
|
||||
ipywidgets==8.1.1
|
||||
Pillow
|
||||
qianfan==0.3.1
|
||||
|
|
|
|||
|
|
@ -42,15 +42,3 @@ mock_llm_config_zhipu = LLMConfig(
|
|||
model="mock_zhipu_model",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_spark = LLMConfig(
|
||||
api_type="spark",
|
||||
app_id="xxx",
|
||||
api_key="xxx",
|
||||
api_secret="xxx",
|
||||
domain="generalv2",
|
||||
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
|
||||
)
|
||||
|
||||
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
|
||||
|
|
|
|||
|
|
@ -1,117 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : default request & response data for provider unittest
|
||||
|
||||
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from qianfan.resources.typing import QfResponse
|
||||
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
prompt = "who are you?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
resp_cont_tmpl = "I'm {name}"
|
||||
default_resp_cont = resp_cont_tmpl.format(name="GPT")
|
||||
|
||||
|
||||
# part of whole ChatCompletion of openai like structure
|
||||
def get_part_chat_completion(name: str) -> dict:
|
||||
part_chat_completion = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": resp_cont_tmpl.format(name=name),
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
return part_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion(name: str) -> ChatCompletion:
|
||||
openai_chat_completion = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion.chunk",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
return openai_chat_completion_chunk
|
||||
|
||||
|
||||
# For gemini
|
||||
gemini_messages = [{"role": "user", "parts": prompt}]
|
||||
|
||||
|
||||
# For QianFan
|
||||
qf_jsonbody_dict = {
|
||||
"id": "as-4v1h587fyv",
|
||||
"object": "chat.completion",
|
||||
"created": 1695021339,
|
||||
"result": "",
|
||||
"is_truncated": False,
|
||||
"need_clear_history": False,
|
||||
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
|
||||
}
|
||||
|
||||
|
||||
def get_qianfan_response(name: str) -> QfResponse:
|
||||
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
|
||||
return QfResponse(code=200, body=qf_jsonbody_dict)
|
||||
|
||||
|
||||
# For llm general chat functions call
|
||||
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
|
||||
resp = await llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
|
@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion
|
|||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Claude")
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
||||
|
||||
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
|
||||
assert resp == Claude2(mock_llm_config).ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)
|
||||
assert resp == await Claude2(mock_llm_config).aask(prompt)
|
||||
|
|
|
|||
|
|
@ -11,13 +11,21 @@ import pytest
|
|||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
get_part_chat_completion,
|
||||
prompt,
|
||||
)
|
||||
|
||||
name = "GPT"
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
|
|
@ -25,13 +33,16 @@ class MockBaseLLM(BaseLLM):
|
|||
pass
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return default_resp_cont
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
|
||||
|
||||
def test_base_llm():
|
||||
|
|
@ -75,25 +86,25 @@ def test_base_llm():
|
|||
choice_text = base_llm.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
# resp = base_llm.ask(prompt)
|
||||
# assert resp == default_resp_cont
|
||||
# resp = base_llm.ask(prompt_msg)
|
||||
# assert resp == resp_content
|
||||
|
||||
# resp = base_llm.ask_batch([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
# resp = base_llm.ask_batch([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
# resp = base_llm.ask_code([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
# resp = base_llm.ask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_llm():
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
resp = await base_llm.aask(prompt)
|
||||
assert resp == default_resp_cont
|
||||
resp = await base_llm.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_llm.aask_batch([prompt])
|
||||
assert resp == default_resp_cont
|
||||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
# resp = await base_llm.aask_code([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
# resp = await base_llm.aask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -3,7 +3,14 @@
|
|||
# @Desc : the unittest of fireworks api
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import (
|
||||
|
|
@ -13,19 +20,42 @@ from metagpt.provider.fireworks_api import (
|
|||
)
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
get_openai_chat_completion_chunk,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
name = "fireworks"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=dict(default_resp.usage),
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def test_fireworks_costmanager():
|
||||
|
|
@ -58,17 +88,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
fireworks_llm = FireworksLLM(mock_llm_config)
|
||||
fireworks_llm.model = "llama-v2-13b-chat"
|
||||
fireworks_gpt = FireworksLLM(mock_llm_config)
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_llm._update_costs(
|
||||
fireworks_gpt._update_costs(
|
||||
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
|
||||
)
|
||||
assert fireworks_llm.get_costs() == Costs(
|
||||
assert fireworks_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
|
||||
)
|
||||
|
||||
resp = await fireworks_llm.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_cont
|
||||
resp = await fireworks_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -11,12 +11,6 @@ from google.generativeai.types import content_types
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
gemini_messages,
|
||||
llm_general_chat_funcs_test,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -24,8 +18,10 @@ class MockGeminiResponse(ABC):
|
|||
text: str
|
||||
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="gemini")
|
||||
default_resp = MockGeminiResponse(text=resp_cont)
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
|
||||
|
||||
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
|
|
@ -64,18 +60,28 @@ async def test_gemini_acompletion(mocker):
|
|||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_llm = GeminiLLM(mock_llm_config)
|
||||
gemini_gpt = GeminiLLM(mock_llm_config)
|
||||
|
||||
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
|
||||
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
|
||||
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
|
||||
usage = gemini_gpt.get_usage(messages, resp_content)
|
||||
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
|
||||
|
||||
resp = gemini_llm.completion(gemini_messages)
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp == default_resp
|
||||
|
||||
resp = await gemini_llm.acompletion(gemini_messages)
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -9,15 +9,12 @@ import pytest
|
|||
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="ollama")
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
|
||||
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
|
||||
|
|
@ -44,12 +41,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
|
|||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_llm = OllamaLLM(mock_llm_config)
|
||||
ollama_gpt = OllamaLLM(mock_llm_config)
|
||||
|
||||
resp = await ollama_llm.acompletion(messages)
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -3,26 +3,53 @@
|
|||
# @Desc :
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
get_openai_chat_completion_chunk,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
|
||||
resp_content = "I'm llama2"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703302755,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
name = "llama2-7b"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name)
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
|
|
@ -41,16 +68,25 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_llm = OpenLLM(mock_llm_config)
|
||||
openllm_llm.model = "llama-v2-13b-chat"
|
||||
openllm_gpt = OpenLLM(mock_llm_config)
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_llm.cost_manager = CostManager()
|
||||
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_llm.get_costs() == Costs(
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
|
||||
)
|
||||
|
||||
resp = await openllm_llm.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_cont
|
||||
resp = await openllm_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)
|
||||
resp = await openllm_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -1,56 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of qianfan api
|
||||
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
import pytest
|
||||
from qianfan.resources.typing import JsonBody, QfResponse
|
||||
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_qianfan_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "ERNIE-Bot-turbo"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
|
||||
return get_qianfan_response(name=name)
|
||||
|
||||
|
||||
async def mock_qianfan_ado(
|
||||
self, messages: list[dict], model: str, stream: bool = True, system: str = None
|
||||
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
|
||||
resps = [get_qianfan_response(name=name)]
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[JsonBody]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qianfan_acompletion(mocker):
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
|
||||
|
||||
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
|
||||
|
||||
resp = qianfan_llm.completion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
resp = await qianfan_llm.acompletion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
|
||||
|
|
@ -4,18 +4,12 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
|
|
@ -29,7 +23,7 @@ class MockWebSocketApp(object):
|
|||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
|
|
@ -37,26 +31,34 @@ def test_get_msg_from_web(mocker):
|
|||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_cont
|
||||
return resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
async def test_spark_aask():
|
||||
llm = SparkLLM(Config.from_home("spark.yaml").llm)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
assert resp == resp_cont
|
||||
print(resp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
spark_gpt = SparkLLM(mock_llm_config)
|
||||
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -6,24 +6,22 @@ import pytest
|
|||
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_part_chat_completion,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "ChatGLM-4"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_part_chat_completion(name)
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate_stream(self, **kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -48,12 +46,22 @@ async def test_zhipuai_acompletion(mocker):
|
|||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
||||
resp = await zhipu_llm.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_cont
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
|
||||
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
def test_zhipuai_proxy():
|
||||
|
|
|
|||
7
tests/spark.yaml
Normal file
7
tests/spark.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
llm:
|
||||
api_type: "spark"
|
||||
app_id: "xxx"
|
||||
api_key: "xxx"
|
||||
api_secret: "xxx"
|
||||
domain: "generalv2"
|
||||
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"
|
||||
Loading…
Add table
Add a link
Reference in a new issue