diff --git a/.github/workflows/fulltest.yaml b/.github/workflows/fulltest.yaml index 70c800481..f5c6049e1 100644 --- a/.github/workflows/fulltest.yaml +++ b/.github/workflows/fulltest.yaml @@ -54,6 +54,7 @@ jobs: export ALLOW_OPENAI_API_CALL=0 echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml + echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index afa9faba7..2e7e3ce2b 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -31,7 +31,7 @@ jobs: - name: Test with pytest run: | export ALLOW_OPENAI_API_CALL=0 - mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml + mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index e22edbdf2..219a303c8 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -13,18 +13,7 @@ from metagpt.logs import logger async def main(): llm = LLM() - # llm type check - id_ques = "what's your name" - logger.info(f"{id_ques}: ") - logger.info(await llm.aask(id_ques)) - logger.info("\n\n") - - logger.info( - await llm.aask( - "who are you", system_msgs=["act as a robot, answer 'I'am robot' if the question is 'who are you'"] - ) - ) - + logger.info(await llm.aask("hello world")) logger.info(await llm.aask_batch(["hi", "write python hello world."])) hello_msg = [{"role": "user", "content": "count from 1 to 10. 
split by newline."}] diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 1b05b5270..fb923d3e4 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -24,7 +24,6 @@ class LLMType(Enum): METAGPT = "metagpt" AZURE = "azure" OLLAMA = "ollama" - QIANFAN = "qianfan" # Baidu BCE def __missing__(self, key): return self.OPENAI @@ -37,18 +36,13 @@ class LLMConfig(YamlModel): Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields """ - api_key: str = "sk-" + api_key: str api_type: LLMType = LLMType.OPENAI base_url: str = "https://api.openai.com/v1" api_version: Optional[str] = None model: Optional[str] = None # also stands for DEPLOYMENT_NAME - # For Cloud Service Provider like Baidu/ Alibaba - access_key: Optional[str] = None - secret_key: Optional[str] = None - endpoint: Optional[str] = None # for self-deployed model on the cloud - # For Spark(Xunfei), maybe remove later app_id: Optional[str] = None api_secret: Optional[str] = None diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 8c0aab836..675734811 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -16,7 +16,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.metagpt_api import MetaGPTLLM from metagpt.provider.human_provider import HumanProvider from metagpt.provider.spark_api import SparkLLM -from metagpt.provider.qianfan_api import QianFanLLM __all__ = [ "FireworksLLM", @@ -29,5 +28,4 @@ __all__ = [ "OllamaLLM", "HumanProvider", "SparkLLM", - "QianFanLLM", ] diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 2f57b15aa..b144471b5 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -11,12 +11,11 @@ from abc import ABC, abstractmethod from typing import Optional, Union from openai import AsyncOpenAI -from pydantic import BaseModel from metagpt.configs.llm_config import LLMConfig from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import CostManager class BaseLLM(ABC): @@ -68,28 +67,6 @@ class BaseLLM(ABC): def _default_system_msg(self): return self._system_msg(self.system_prompt) - def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True): - """update each request's token cost - Args: - model (str): model name or in some scenarios called endpoint - local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage - """ - calc_usage = self.config.calc_usage and local_calc_usage - model = model if model else self.model - usage = usage.model_dump() if isinstance(usage, BaseModel) else usage - if calc_usage and self.cost_manager: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.cost_manager.update_cost(prompt_tokens, completion_tokens, model) - except Exception as e: - logger.error(f"{self.__class__.__name__} updats costs failed! 
exp: {e}") - - def get_costs(self) -> Costs: - if not self.cost_manager: - return Costs(0, 0, 0, 0) - return self.cost_manager.get_costs() - async def aask( self, msg: str, diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index e62a7066e..d56453a85 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM, log_and_reraise -from metagpt.utils.cost_manager import CostManager +from metagpt.utils.cost_manager import CostManager, Costs MODEL_GRADE_TOKEN_COSTS = { "-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition @@ -81,6 +81,17 @@ class FireworksLLM(OpenAILLM): kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url) return kwargs + def _update_costs(self, usage: CompletionUsage): + if self.config.calc_usage and usage: + try: + # use FireworksCostManager not context.cost_manager + self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + except Exception as e: + logger.error(f"updating costs failed!, exp: {e}") + + def get_costs(self) -> Costs: + return self.cost_manager.get_costs() + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages), stream=True @@ -102,7 +113,7 @@ class FireworksLLM(OpenAILLM): usage = CompletionUsage(**chunk.usage) full_content = "".join(collected_content) - self._update_costs(usage.model_dump()) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 87ea81c80..2647ab16b 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -72,6 +72,16 @@ class GeminiLLM(BaseLLM): kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs + def _update_costs(self, usage: dict): + """update each request's token cost""" + if self.config.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error(f"google gemini updats costs failed! exp: {e}") + def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 52e8dbe36..c9103b018 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -46,6 +46,16 @@ class OllamaLLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs + def _update_costs(self, usage: dict): + """update each request's token cost""" + if self.config.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error(f"ollama updats costs failed! 
exp: {e}") + def get_choice_text(self, resp: dict) -> str: """get the resp content from llm response""" assist_msg = resp.get("message", {}) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 69371e379..a29b263a4 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM -from metagpt.utils.cost_manager import TokenCostManager +from metagpt.utils.cost_manager import Costs, TokenCostManager from metagpt.utils.token_counter import count_message_tokens, count_string_tokens @@ -34,3 +34,14 @@ class OpenLLM(OpenAILLM): logger.error(f"usage calculation failed!: {e}") return usage + + def _update_costs(self, usage: CompletionUsage): + if self.config.calc_usage and usage: + try: + # use OpenLLMCostManager not CONFIG.cost_manager + self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + except Exception as e: + logger.error(f"updating costs failed!, exp: {e}") + + def get_costs(self) -> Costs: + return self._cost_manager.get_costs() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 2ae14f437..fe41fb05f 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -29,7 +29,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.common import CodeParser, decode_image -from metagpt.utils.cost_manager import CostManager +from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( count_message_tokens, @@ -55,13 +55,16 @@ class OpenAILLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config + self._init_model() self._init_client() self.auto_max_tokens = False self.cost_manager: Optional[CostManager] = None + def _init_model(self): + self.model = self.config.model # Used in _calc_usage & _cons_kwargs + def _init_client(self): """https://github.com/openai/openai-python#async-usage""" - self.model = self.config.model # Used in _calc_usage & _cons_kwargs kwargs = self._make_client_kwargs() self.aclient = AsyncOpenAI(**kwargs) @@ -237,6 +240,16 @@ class OpenAILLM(BaseLLM): return usage + @handle_exception + def _update_costs(self, usage: CompletionUsage): + if self.config.calc_usage and usage and self.cost_manager: + self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + + def get_costs(self) -> Costs: + if not self.cost_manager: + return Costs(0, 0, 0, 0) + return self.cost_manager.get_costs() + def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return self.config.max_token diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py deleted file mode 100644 index 6f94b9cea..000000000 --- a/metagpt/provider/qianfan_api.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models -import copy -import os - -import qianfan -from qianfan import ChatCompletion -from qianfan.resources.typing import JsonBody -from tenacity import ( - after_log, - retry, - retry_if_exception_type, - stop_after_attempt, - 
wait_random_exponential, -) - -from metagpt.configs.llm_config import LLMConfig, LLMType -from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_llm import BaseLLM -from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import log_and_reraise -from metagpt.utils.cost_manager import CostManager -from metagpt.utils.token_counter import ( - QianFan_EndPoint_TOKEN_COSTS, - QianFan_MODEL_TOKEN_COSTS, -) - - -@register_provider(LLMType.QIANFAN) -class QianFanLLM(BaseLLM): - """ - Refs - Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B - Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9 - Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8 - """ - - def __init__(self, config: LLMConfig): - self.config = config - self.use_system_prompt = False # only some ERNIE-x related models support system_prompt - self.__init_qianfan() - self.cost_manager = CostManager(token_costs=self.token_costs) - - def __init_qianfan(self): - if self.config.access_key and self.config.secret_key: - # for system level auth, use access_key and secret_key, recommended by official - # set environment variable due to official recommendation - os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key) - os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key) - elif self.config.api_key and self.config.secret_key: - # for application level auth, use api_key and secret_key - # set environment variable due to official recommendation - os.environ.setdefault("QIANFAN_AK", self.config.api_key) - os.environ.setdefault("QIANFAN_SK", self.config.secret_key) - else: - raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first") - - support_system_pairs = [ - ("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint) - ("ERNIE-Bot-8k", "ernie_bot_8k"), - ("ERNIE-Bot", "completions"), - ("ERNIE-Bot-turbo", "eb-instant"), - ("ERNIE-Speed", "ernie_speed"), - ("EB-turbo-AppBuilder", "ai_apaas"), - ] - if self.config.model in [pair[0] for pair in support_system_pairs]: - # only some ERNIE models support - self.use_system_prompt = True - if self.config.endpoint in [pair[1] for pair in support_system_pairs]: - self.use_system_prompt = True - - assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config" - assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config" - - self.token_costs = copy.deepcopy(QianFan_MODEL_TOKEN_COSTS) - self.token_costs.update(QianFan_EndPoint_TOKEN_COSTS) - - # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee - self.calc_usage = self.config.calc_usage and self.config.endpoint is None - self.aclient: ChatCompletion = qianfan.ChatCompletion() - - def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - kwargs = { - "messages": messages, - "stream": stream, - } - if self.config.temperature > 0: - # different model has default temperature. only set when it's specified. 
- kwargs["temperature"] = self.config.temperature - if self.config.endpoint: - kwargs["endpoint"] = self.config.endpoint - elif self.config.model: - kwargs["model"] = self.config.model - - if self.use_system_prompt: - # if the model support system prompt, extract and pass it - if messages[0]["role"] == "system": - kwargs["messages"] = messages[1:] - kwargs["system"] = messages[0]["content"] # set system prompt here - return kwargs - - def _update_costs(self, usage: dict): - """update each request's token cost""" - model_or_endpoint = self.config.model if self.config.model else self.config.endpoint - local_calc_usage = True if model_or_endpoint in self.token_costs else False - super()._update_costs(usage, model_or_endpoint, local_calc_usage) - - def get_choice_text(self, resp: JsonBody) -> str: - return resp.get("result", "") - - def completion(self, messages: list[dict]) -> JsonBody: - resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False)) - self._update_costs(resp.body.get("usage", {})) - return resp.body - - async def _achat_completion(self, messages: list[dict]) -> JsonBody: - resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False)) - self._update_costs(resp.body.get("usage", {})) - return resp.body - - async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody: - return await self._achat_completion(messages) - - async def _achat_completion_stream(self, messages: list[dict]) -> str: - resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True)) - collected_content = [] - usage = {} - async for chunk in resp: - content = chunk.body.get("result", "") - usage = chunk.body.get("usage", {}) - log_llm_stream(content) - collected_content.append(content) - log_llm_stream("\n") - - self._update_costs(usage) - full_content = "".join(collected_content) - return full_content - - @retry( - stop=stop_after_attempt(3), - wait=wait_random_exponential(min=1, max=60), - after=after_log(logger, logger.level("WARNING").name), - retry=retry_if_exception_type(ConnectionError), - retry_error_callback=log_and_reraise, - ) - async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: - if stream: - return await self._achat_completion_stream(messages) - resp = await self._achat_completion(messages) - return self.get_choice_text(resp) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 4cbee4038..9e8e5fb53 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -53,6 +53,16 @@ class ZhiPuAILLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} return kwargs + def _update_costs(self, usage: dict): + """update each request's token cost""" + if self.config.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error(f"zhipuai updats costs failed! 
exp: {e}") + def completion(self, messages: list[dict], timeout=3) -> dict: resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages)) usage = resp.usage.model_dump() diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py index 4e6b65b2c..c4c93f91f 100644 --- a/metagpt/utils/cost_manager.py +++ b/metagpt/utils/cost_manager.py @@ -29,7 +29,6 @@ class CostManager(BaseModel): total_budget: float = 0 max_budget: float = 10.0 total_cost: float = 0 - token_costs: dict[str, dict[str, float]] = TOKEN_COSTS def update_cost(self, prompt_tokens, completion_tokens, model): """ @@ -47,8 +46,7 @@ class CostManager(BaseModel): return cost = ( - prompt_tokens * self.token_costs[model]["prompt"] - + completion_tokens * self.token_costs[model]["completion"] + prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"] ) / 1000 self.total_cost += cost logger.info( diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 2ec0edc99..65f5fe76f 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -38,59 +38,6 @@ TOKEN_COSTS = { } -""" -QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9 -Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method. -""" -QianFan_MODEL_TOKEN_COSTS = { - "ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017}, - "ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067}, - "ERNIE-Bot": {"prompt": 0.017, "completion": 0.017}, - "ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011}, - "EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011}, - "ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011}, - "BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056}, - "Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056}, - "Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084}, - "Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049}, - "ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056}, - "AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056}, - "Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049}, - "SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056}, - "CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056}, - "XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049}, - "Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056}, - "Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056}, - "Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084}, - "ChatLaw": {"prompt": 0.0011, "completion": 0.0011}, - "Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0}, -} - -QianFan_EndPoint_TOKEN_COSTS = { - "completions_pro": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-4"], - "ernie_bot_8k": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"], - "completions": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot"], - "eb-instant": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"], - "ai_apaas": QianFan_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"], - "ernie_speed": QianFan_MODEL_TOKEN_COSTS["ERNIE-Speed"], - "bloomz_7b1": QianFan_MODEL_TOKEN_COSTS["BLOOMZ-7B"], - "llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"], - "llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"], - "llama_2_70b": QianFan_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"], - "chatglm2_6b_32k": QianFan_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"], - "aquilachat_7b": QianFan_MODEL_TOKEN_COSTS["AquilaChat-7B"], - 
"mixtral_8x7b_instruct": QianFan_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"], - "sqlcoder_7b": QianFan_MODEL_TOKEN_COSTS["SQLCoder-7B"], - "codellama_7b_instruct": QianFan_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"], - "xuanyuan_70b_chat": QianFan_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"], - "qianfan_bloomz_7b_compressed": QianFan_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"], - "qianfan_chinese_llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"], - "qianfan_chinese_llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"], - "chatlaw": QianFan_MODEL_TOKEN_COSTS["ChatLaw"], - "yi_34b_chat": QianFan_MODEL_TOKEN_COSTS["Yi-34B-Chat"], -} - - TOKEN_MAX = { "gpt-3.5-turbo": 4096, "gpt-3.5-turbo-0301": 4096, diff --git a/requirements.txt b/requirements.txt index b5d8d7d51..1426500ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -67,4 +67,3 @@ playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py anytree ipywidgets==8.1.1 Pillow -qianfan==0.3.1 diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index e0afaa51e..e2f626a6a 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -42,15 +42,3 @@ mock_llm_config_zhipu = LLMConfig( model="mock_zhipu_model", proxy="http://localhost:8080", ) - - -mock_llm_config_spark = LLMConfig( - api_type="spark", - app_id="xxx", - api_key="xxx", - api_secret="xxx", - domain="generalv2", - base_url="wss://spark-api.xf-yun.com/v3.1/chat", -) - -mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo") diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py deleted file mode 100644 index 73939e1c6..000000000 --- a/tests/metagpt/provider/req_resp_const.py +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : default request & response data for provider unittest - - -from openai.types.chat.chat_completion import ( - ChatCompletion, - ChatCompletionMessage, - Choice, -) -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.chat.chat_completion_chunk import Choice as AChoice -from openai.types.chat.chat_completion_chunk import ChoiceDelta -from openai.types.completion_usage import CompletionUsage -from qianfan.resources.typing import QfResponse - -from metagpt.provider.base_llm import BaseLLM - -prompt = "who are you?" 
-messages = [{"role": "user", "content": prompt}] - -resp_cont_tmpl = "I'm {name}" -default_resp_cont = resp_cont_tmpl.format(name="GPT") - - -# part of whole ChatCompletion of openai like structure -def get_part_chat_completion(name: str) -> dict: - part_chat_completion = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": resp_cont_tmpl.format(name=name), - }, - "finish_reason": "stop", - } - ], - "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, - } - return part_chat_completion - - -def get_openai_chat_completion(name: str) -> ChatCompletion: - openai_chat_completion = ChatCompletion( - id="cmpl-a6652c1bb181caae8dd19ad8", - model="xx/xxx", - object="chat.completion", - created=1703300855, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), - logprobs=None, - ) - ], - usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), - ) - return openai_chat_completion - - -def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: - usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) - usage = usage if not usage_as_dict else usage.model_dump() - openai_chat_completion_chunk = ChatCompletionChunk( - id="cmpl-a6652c1bb181caae8dd19ad8", - model="xx/xxx", - object="chat.completion.chunk", - created=1703300855, - choices=[ - AChoice( - delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], - usage=usage, - ) - return openai_chat_completion_chunk - - -# For gemini -gemini_messages = [{"role": "user", "parts": prompt}] - - -# For QianFan -qf_jsonbody_dict = { - "id": "as-4v1h587fyv", - "object": "chat.completion", - "created": 1695021339, - "result": "", - "is_truncated": False, - "need_clear_history": False, - "usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22}, -} - - -def get_qianfan_response(name: str) -> QfResponse: - qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name) - return QfResponse(code=200, body=qf_jsonbody_dict) - - -# For llm general chat functions call -async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str): - resp = await llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await llm.aask(prompt) - assert resp == resp_cont - - resp = await llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await llm.acompletion_text(messages, stream=True) - assert resp == resp_cont diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py index 93cfd7dbc..6962ab064 100644 --- a/tests/metagpt/provider/test_anthropic_api.py +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion from metagpt.provider.anthropic_api import Claude2 from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl -resp_cont = resp_cont_tmpl.format(name="Claude") +prompt = "who are you" +resp = "I'am Claude2" def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: - return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion") + return Completion(id="xx", 
completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: - return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion") + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") def test_claude2_ask(mocker): mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create) - assert resp_cont == Claude2(mock_llm_config).ask(prompt) + assert resp == Claude2(mock_llm_config).ask(prompt) @pytest.mark.asyncio async def test_claude2_aask(mocker): mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create) - assert resp_cont == await Claude2(mock_llm_config).aask(prompt) + assert resp == await Claude2(mock_llm_config).aask(prompt) diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index cf44343bc..cc781f78a 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -11,13 +11,21 @@ import pytest from metagpt.configs.llm_config import LLMConfig from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message -from tests.metagpt.provider.req_resp_const import ( - default_resp_cont, - get_part_chat_completion, - prompt, -) -name = "GPT" +default_chat_resp = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I'am GPT", + }, + "finish_reason": "stop", + } + ] +} +prompt_msg = "who are you" +resp_content = default_chat_resp["choices"][0]["message"]["content"] class MockBaseLLM(BaseLLM): @@ -25,13 +33,16 @@ class MockBaseLLM(BaseLLM): pass def completion(self, messages: list[dict], timeout=3): - return get_part_chat_completion(name) + return default_chat_resp async def acompletion(self, messages: list[dict], timeout=3): - return get_part_chat_completion(name) + return default_chat_resp async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: - return default_resp_cont + return resp_content + + async def close(self): + return default_chat_resp def test_base_llm(): @@ -75,25 +86,25 @@ def test_base_llm(): choice_text = base_llm.get_choice_text(openai_funccall_resp) assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"] - # resp = base_llm.ask(prompt) - # assert resp == default_resp_cont + # resp = base_llm.ask(prompt_msg) + # assert resp == resp_content - # resp = base_llm.ask_batch([prompt]) - # assert resp == default_resp_cont + # resp = base_llm.ask_batch([prompt_msg]) + # assert resp == resp_content - # resp = base_llm.ask_code([prompt]) - # assert resp == default_resp_cont + # resp = base_llm.ask_code([prompt_msg]) + # assert resp == resp_content @pytest.mark.asyncio async def test_async_base_llm(): base_llm = MockBaseLLM() - resp = await base_llm.aask(prompt) - assert resp == default_resp_cont + resp = await base_llm.aask(prompt_msg) + assert resp == resp_content - resp = await base_llm.aask_batch([prompt]) - assert resp == default_resp_cont + resp = await base_llm.aask_batch([prompt_msg]) + assert resp == resp_content - # resp = await base_llm.aask_code([prompt]) - # assert resp == default_resp_cont + # resp = await base_llm.aask_code([prompt_msg]) + # assert resp == resp_content diff --git a/tests/metagpt/provider/test_fireworks_llm.py 
b/tests/metagpt/provider/test_fireworks_llm.py index 1c1aa9caa..66b55e5b2 100644 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ b/tests/metagpt/provider/test_fireworks_llm.py @@ -3,7 +3,14 @@ # @Desc : the unittest of fireworks api import pytest +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.provider.fireworks_api import ( @@ -13,19 +20,42 @@ from metagpt.provider.fireworks_api import ( ) from metagpt.utils.cost_manager import Costs from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import ( - get_openai_chat_completion, - get_openai_chat_completion_chunk, - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, + +resp_content = "I'm fireworks" +default_resp = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="accounts/fireworks/models/llama-v2-13b-chat", + object="chat.completion", + created=1703300855, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, + ) + ], + usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) -name = "fireworks" -resp_cont = resp_cont_tmpl.format(name=name) -default_resp = get_openai_chat_completion(name) -default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True) +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta(content=resp_content, role="assistant"), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + usage=dict(default_resp.usage), +) + +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] def test_fireworks_costmanager(): @@ -58,17 +88,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def test_fireworks_acompletion(mocker): mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - fireworks_llm = FireworksLLM(mock_llm_config) - fireworks_llm.model = "llama-v2-13b-chat" + fireworks_gpt = FireworksLLM(mock_llm_config) + fireworks_gpt.model = "llama-v2-13b-chat" - fireworks_llm._update_costs( + fireworks_gpt._update_costs( usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) ) - assert fireworks_llm.get_costs() == Costs( + assert fireworks_gpt.get_costs() == Costs( total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 ) - resp = await fireworks_llm.acompletion(messages) - assert resp.choices[0].message.content in resp_cont + resp = await fireworks_gpt.acompletion(messages) + assert resp.choices[0].message.content in resp_content - await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont) + resp = await fireworks_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await fireworks_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await fireworks_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await fireworks_gpt.aask(prompt_msg) 
+ assert resp == resp_content diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 50c15ee19..404ae1e90 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -11,12 +11,6 @@ from google.generativeai.types import content_types from metagpt.provider.google_gemini_api import GeminiLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import ( - gemini_messages, - llm_general_chat_funcs_test, - prompt, - resp_cont_tmpl, -) @dataclass @@ -24,8 +18,10 @@ class MockGeminiResponse(ABC): text: str -resp_cont = resp_cont_tmpl.format(name="gemini") -default_resp = MockGeminiResponse(text=resp_cont) +prompt_msg = "who are you" +messages = [{"role": "user", "parts": prompt_msg}] +resp_content = "I'm gemini from google" +default_resp = MockGeminiResponse(text=resp_content) def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: @@ -64,18 +60,28 @@ async def test_gemini_acompletion(mocker): mock_gemini_generate_content_async, ) - gemini_llm = GeminiLLM(mock_llm_config) + gemini_gpt = GeminiLLM(mock_llm_config) - assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]} - assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]} + assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} + assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} - usage = gemini_llm.get_usage(gemini_messages, resp_cont) + usage = gemini_gpt.get_usage(messages, resp_content) assert usage == {"prompt_tokens": 20, "completion_tokens": 20} - resp = gemini_llm.completion(gemini_messages) + resp = gemini_gpt.completion(messages) assert resp == default_resp - resp = await gemini_llm.acompletion(gemini_messages) + resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text - await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont) + resp = await gemini_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await gemini_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await gemini_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await gemini_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index af2e929e9..5d942598b 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -9,15 +9,12 @@ import pytest from metagpt.provider.ollama_api import OllamaLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import ( - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, -) -resp_cont = resp_cont_tmpl.format(name="ollama") -default_resp = {"message": {"role": "assistant", "content": resp_cont}} +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + +resp_content = "I'm ollama" +default_resp = {"message": {"role": "assistant", "content": resp_content}} async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: @@ -44,12 +41,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An async def test_gemini_acompletion(mocker): 
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) - ollama_llm = OllamaLLM(mock_llm_config) + ollama_gpt = OllamaLLM(mock_llm_config) - resp = await ollama_llm.acompletion(messages) + resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] - resp = await ollama_llm.aask(prompt, stream=False) - assert resp == resp_cont + resp = await ollama_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content - await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont) + resp = await ollama_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await ollama_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await ollama_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index aa38b95a6..fc7b510cc 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -3,26 +3,53 @@ # @Desc : import pytest +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.provider.open_llm_api import OpenLLM -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import Costs from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import ( - get_openai_chat_completion, - get_openai_chat_completion_chunk, - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, + +resp_content = "I'm llama2" +default_resp = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="llama-v2-13b-chat", + object="chat.completion", + created=1703302755, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, + ) + ], ) -name = "llama2-7b" -resp_cont = resp_cont_tmpl.format(name=name) -default_resp = get_openai_chat_completion(name) +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta(content=resp_content, role="assistant"), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], +) -default_resp_chunk = get_openai_chat_completion_chunk(name) +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: @@ -41,16 +68,25 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def test_openllm_acompletion(mocker): mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - openllm_llm = OpenLLM(mock_llm_config) - openllm_llm.model = "llama-v2-13b-chat" + openllm_gpt = OpenLLM(mock_llm_config) + openllm_gpt.model = "llama-v2-13b-chat" - openllm_llm.cost_manager = CostManager() - openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) - assert openllm_llm.get_costs() == Costs( 
+ openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) + assert openllm_gpt.get_costs() == Costs( total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 ) - resp = await openllm_llm.acompletion(messages) - assert resp.choices[0].message.content in resp_cont + resp = await openllm_gpt.acompletion(messages) + assert resp.choices[0].message.content in resp_content - await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont) + resp = await openllm_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await openllm_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py deleted file mode 100644 index 28341425c..000000000 --- a/tests/metagpt/provider/test_qianfan_api.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : the unittest of qianfan api - -from typing import AsyncIterator, Union - -import pytest -from qianfan.resources.typing import JsonBody, QfResponse - -from metagpt.provider.qianfan_api import QianFanLLM -from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan -from tests.metagpt.provider.req_resp_const import ( - get_qianfan_response, - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, -) - -name = "ERNIE-Bot-turbo" -resp_cont = resp_cont_tmpl.format(name=name) - - -def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse: - return get_qianfan_response(name=name) - - -async def mock_qianfan_ado( - self, messages: list[dict], model: str, stream: bool = True, system: str = None -) -> Union[QfResponse, AsyncIterator[QfResponse]]: - resps = [get_qianfan_response(name=name)] - if stream: - - async def aresp_iterator(resps: list[JsonBody]): - for resp in resps: - yield resp - - return aresp_iterator(resps) - else: - return resps[0] - - -@pytest.mark.asyncio -async def test_qianfan_acompletion(mocker): - mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do) - mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado) - - qianfan_llm = QianFanLLM(mock_llm_config_qianfan) - - resp = qianfan_llm.completion(messages) - assert resp.get("result") == resp_cont - - resp = await qianfan_llm.acompletion(messages) - assert resp.get("result") == resp_cont - - await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 9c278267d..f5a6f66fd 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,18 +4,12 @@ import pytest +from metagpt.config2 import Config from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM -from tests.metagpt.provider.mock_llm_config import ( - mock_llm_config, - mock_llm_config_spark, -) -from tests.metagpt.provider.req_resp_const import ( - llm_general_chat_funcs_test, - prompt, - resp_cont_tmpl, -) +from tests.metagpt.provider.mock_llm_config import mock_llm_config -resp_cont = resp_cont_tmpl.format(name="Spark") +prompt_msg = "who are you" +resp_content = "I'm Spark" class 
MockWebSocketApp(object): @@ -29,7 +23,7 @@ class MockWebSocketApp(object): def test_get_msg_from_web(mocker): mocker.patch("websocket.WebSocketApp", MockWebSocketApp) - get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config) + get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config) assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain" ret = get_msg_from_web.run() @@ -37,26 +31,34 @@ def test_get_msg_from_web(mocker): def mock_spark_get_msg_from_web_run(self) -> str: - return resp_cont + return resp_content @pytest.mark.asyncio -async def test_spark_aask(mocker): - mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - - llm = SparkLLM(mock_llm_config_spark) +async def test_spark_aask(): + llm = SparkLLM(Config.from_home("spark.yaml").llm) resp = await llm.aask("Hello!") - assert resp == resp_cont + print(resp) @pytest.mark.asyncio async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - spark_llm = SparkLLM(mock_llm_config) + spark_gpt = SparkLLM(mock_llm_config) - resp = await spark_llm.acompletion([]) - assert resp == resp_cont + resp = await spark_gpt.acompletion([]) + assert resp == resp_content - await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont) + resp = await spark_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await spark_gpt.acompletion_text([], stream=False) + assert resp == resp_content + + resp = await spark_gpt.acompletion_text([], stream=True) + assert resp == resp_content + + resp = await spark_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index c51010122..ad2ececa2 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -6,24 +6,22 @@ import pytest from metagpt.provider.zhipuai_api import ZhiPuAILLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu -from tests.metagpt.provider.req_resp_const import ( - get_part_chat_completion, - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, -) -name = "ChatGLM-4" -resp_cont = resp_cont_tmpl.format(name=name) -default_resp = get_part_chat_completion(name) +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + +resp_content = "I'm chatglm-turbo" +default_resp = { + "choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}], + "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, +} async def mock_zhipuai_acreate_stream(self, **kwargs): class MockResponse(object): async def _aread(self): class Iterator(object): - events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}] + events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}] async def __aiter__(self): for event in self.events: @@ -48,12 +46,22 @@ async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate) mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream) - zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu) + zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu) - resp = await zhipu_llm.acompletion(messages) - assert resp["choices"][0]["message"]["content"] 
== resp_cont + resp = await zhipu_gpt.acompletion(messages) + assert resp["choices"][0]["message"]["content"] == resp_content - await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont) + resp = await zhipu_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await zhipu_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await zhipu_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await zhipu_gpt.aask(prompt_msg) + assert resp == resp_content def test_zhipuai_proxy(): diff --git a/tests/spark.yaml b/tests/spark.yaml new file mode 100644 index 000000000..a5bbd98bd --- /dev/null +++ b/tests/spark.yaml @@ -0,0 +1,7 @@ +llm: + api_type: "spark" + app_id: "xxx" + api_key: "xxx" + api_secret: "xxx" + domain: "generalv2" + base_url: "wss://spark-api.xf-yun.com/v3.1/chat" \ No newline at end of file
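
Note (not part of the patch): a minimal sketch of how the new standalone Spark config introduced above is consumed, using only names that appear in this diff (`Config.from_home`, `SparkLLM`, `aask`). It assumes `~/.metagpt/spark.yaml` exists with real credentials in the same shape as the new `tests/spark.yaml`; in CI that file is created by decoding the `SPARK_YAML` secret (see the fulltest.yaml change at the top of this patch).

    # sketch only: load ~/.metagpt/spark.yaml and call the Spark provider directly
    import asyncio

    from metagpt.config2 import Config
    from metagpt.provider.spark_api import SparkLLM


    async def main():
        config = Config.from_home("spark.yaml")  # reads ~/.metagpt/spark.yaml
        llm = SparkLLM(config.llm)               # LLMConfig with api_type: "spark"
        print(await llm.aask("Hello!"))


    if __name__ == "__main__":
        asyncio.run(main())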