From 78989b0eb7dd012442cb480ceed217d8ecc28f03 Mon Sep 17 00:00:00 2001 From: yzlin Date: Tue, 6 Feb 2024 23:37:24 +0800 Subject: [PATCH 01/15] skip two individual tests --- tests/metagpt/actions/test_rebuild_class_view.py | 1 + tests/metagpt/actions/test_summarize_code.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py index 403109cc0..2188d6b85 100644 --- a/tests/metagpt/actions/test_rebuild_class_view.py +++ b/tests/metagpt/actions/test_rebuild_class_view.py @@ -14,6 +14,7 @@ from metagpt.actions.rebuild_class_view import RebuildClassView from metagpt.llm import LLM +@pytest.mark.skip @pytest.mark.asyncio async def test_rebuild(context): action = RebuildClassView( diff --git a/tests/metagpt/actions/test_summarize_code.py b/tests/metagpt/actions/test_summarize_code.py index a404047c1..3cfe7ca81 100644 --- a/tests/metagpt/actions/test_summarize_code.py +++ b/tests/metagpt/actions/test_summarize_code.py @@ -176,6 +176,7 @@ class Snake: """ +@pytest.mark.skip @pytest.mark.asyncio async def test_summarize_code(context): git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}" From 6f31289e7e0efd96a22400b31df8179eab286875 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 7 Feb 2024 10:02:15 +0800 Subject: [PATCH 02/15] re-commit zhipu-api due to merge mistake --- examples/llm_hello_world.py | 8 ------- examples/llm_vision.py | 23 ++++++++++++++++++ metagpt/provider/general_api_requestor.py | 3 ++- metagpt/provider/zhipuai_api.py | 28 ++++++++++------------ metagpt/utils/token_counter.py | 7 +++--- tests/metagpt/provider/test_zhipuai_api.py | 4 ++-- 6 files changed, 43 insertions(+), 30 deletions(-) create mode 100644 examples/llm_vision.py diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 1d132eb8a..219a303c8 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -6,11 +6,9 @@ @File : llm_hello_world.py """ import asyncio -from pathlib import Path from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.utils.common import encode_image async def main(): @@ -29,12 +27,6 @@ async def main(): if hasattr(llm, "completion"): logger.info(llm.completion(hello_msg)) - # check if the configured llm supports llm-vision capacity. If not, it will throw a error - invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") - img_base64 = encode_image(invoice_path) - res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64]) - assert "true" in res.lower() - if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/llm_vision.py b/examples/llm_vision.py new file mode 100644 index 000000000..276decd59 --- /dev/null +++ b/examples/llm_vision.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : example to run the ability of LLM vision + +import asyncio +from pathlib import Path + +from metagpt.llm import LLM +from metagpt.utils.common import encode_image + + +async def main(): + llm = LLM() + + # check if the configured llm supports llm-vision capacity. 
If not, it will throw an error
+    invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
+    img_base64 = encode_image(invoice_path)
+    res = await llm.aask(msg="if this is an invoice, just return True else return False", images=[img_base64])
+    assert "true" in res.lower()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py
index 500cd1426..18f4dd909 100644
--- a/metagpt/provider/general_api_requestor.py
+++ b/metagpt/provider/general_api_requestor.py
@@ -60,7 +60,8 @@ class GeneralAPIRequestor(APIRequestor):
         self, result: requests.Response, stream: bool
     ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
         """Returns the response(s) and a bool indicating whether it is a stream."""
-        if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
+        content_type = result.headers.get("Content-Type", "")
+        if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
             return (
                 self._interpret_response_line(line, result.status_code, result.headers, stream=True)
                 for line in parse_stream(result.iter_lines())
diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py
index 9108a1fba..9e8e5fb53 100644
--- a/metagpt/provider/zhipuai_api.py
+++ b/metagpt/provider/zhipuai_api.py
@@ -3,9 +3,8 @@
 # @Desc   : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
 
 from enum import Enum
+from typing import Optional
 
-import openai
-import zhipuai
 from requests import ConnectionError
 from tenacity import (
     after_log,
@@ -14,6 +13,7 @@ from tenacity import (
     stop_after_attempt,
     wait_random_exponential,
 )
+from zhipuai.types.chat.chat_completion import Completion
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.logs import log_llm_stream, logger
@@ -21,6 +21,7 @@ from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.provider.openai_api import log_and_reraise
 from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
+from metagpt.utils.cost_manager import CostManager
 
 
 class ZhiPuEvent(Enum):
@@ -38,20 +39,15 @@ class ZhiPuAILLM(BaseLLM):
     """
 
     def __init__(self, config: LLMConfig):
-        self.__init_zhipuai(config)
-        self.llm = ZhiPuModelAPI
-        self.model = "chatglm_turbo"  # so far only one model, just use it
-        self.use_system_prompt: bool = False  # zhipuai has no system prompt when use api
         self.config = config
+        self.__init_zhipuai()
+        self.cost_manager: Optional[CostManager] = None
 
-    def __init_zhipuai(self, config: LLMConfig):
-        assert config.api_key
-        zhipuai.api_key = config.api_key
-        # due to use openai sdk, set the api_key but it will't be used.
-        # openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
-        if config.proxy:
-            # FIXME: openai v1.x sdk has no proxy support
-            openai.proxy = config.proxy
+    def __init_zhipuai(self):
+        assert self.config.api_key
+        self.api_key = self.config.api_key
+        self.model = self.config.model  # so far, it supports glm-3-turbo and glm-4
+        self.llm = ZhiPuModelAPI(api_key=self.api_key)
 
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
         kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
@@ -63,12 +59,12 @@
         try:
             prompt_tokens = int(usage.get("prompt_tokens", 0))
             completion_tokens = int(usage.get("completion_tokens", 0))
-            self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+            self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
         except Exception as e:
             logger.error(f"zhipuai updats costs failed! exp: {e}")
 
     def completion(self, messages: list[dict], timeout=3) -> dict:
-        resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
+        resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
         usage = resp.usage.model_dump()
         self._update_costs(usage)
         return resp.model_dump()
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index a0fb3b70d..65f5fe76f 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -32,8 +32,8 @@ TOKEN_COSTS = {
     "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03},  # TODO add extra image price calculator
     "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
-    "glm-3-turbo": {"prompt": 0.0, "completion": 0.0007},  # 128k version, prompt + completion tokens=0.005¥/k-tokens
-    "glm-4": {"prompt": 0.0, "completion": 0.014},  # 128k version, prompt + completion tokens=0.1¥/k-tokens
+    "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007},  # 128k version, prompt + completion tokens=0.005¥/k-tokens
+    "glm-4": {"prompt": 0.014, "completion": 0.014},  # 128k version, prompt + completion tokens=0.1¥/k-tokens
     "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
 }
 
@@ -58,7 +58,8 @@ TOKEN_MAX = {
     "gpt-4-vision-preview": 128000,
     "gpt-4-1106-vision-preview": 128000,
     "text-embedding-ada-002": 8192,
-    "chatglm_turbo": 32768,
+    "glm-3-turbo": 128000,
+    "glm-4": 128000,
     "gemini-pro": 32768,
 }
 
diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py
index 798209710..ad2ececa2 100644
--- a/tests/metagpt/provider/test_zhipuai_api.py
+++ b/tests/metagpt/provider/test_zhipuai_api.py
@@ -17,7 +17,7 @@ default_resp = {
 }
 
 
-async def mock_zhipuai_acreate_stream(**kwargs):
+async def mock_zhipuai_acreate_stream(self, **kwargs):
     class MockResponse(object):
         async def _aread(self):
             class Iterator(object):
@@ -37,7 +37,7 @@ async def mock_zhipuai_acreate_stream(**kwargs):
     return MockResponse()
 
 
-async def mock_zhipuai_acreate(**kwargs) -> dict:
+async def mock_zhipuai_acreate(self, **kwargs) -> dict:
     return default_resp
 

From 3b4379d12569cae719ff58f6c39208eed05483aa Mon Sep 17 00:00:00 2001
From: voidking
Date: Wed, 7 Feb 2024 10:34:04 +0800
Subject: [PATCH 03/15] chore: move the required playwright to requirements.txt

---
 requirements.txt | 2 +-
 setup.py         | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 6cb25d52b..804ff4359 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -63,7 +63,7 @@ gitignore-parser==0.1.9
 websockets~=12.0
 networkx~=3.2.1
 google-generativeai==0.3.2
-# playwright==1.40.0 # playwright extras require
+playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
 anytree
 ipywidgets==8.1.1
 Pillow
diff --git a/setup.py b/setup.py
index b16d978cf..be3956ea4 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,6 @@ requirements = (here / "requirements.txt").read_text(encoding="utf-8").splitline
 
 
 extras_require = {
-    "playwright": ["playwright>=1.26", "beautifulsoup4"],
     "selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
     "search-google": ["google-api-python-client==2.94.0"],
     "search-ddg": ["duckduckgo-search~=4.1.1"],

From 63ab24a77bbb850baed77b515941342d48329aca Mon Sep 17 00:00:00 2001
From: voidking
Date: Wed, 7 Feb 2024 11:54:31 +0800
Subject: [PATCH 04/15] chore: add one more space

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 804ff4359..1426500ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -63,7 +63,7 @@ gitignore-parser==0.1.9
 websockets~=12.0
 networkx~=3.2.1
 google-generativeai==0.3.2
-playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
+playwright>=1.26  # used at metagpt/tools/libs/web_scraping.py
 anytree
 ipywidgets==8.1.1
 Pillow

From d180d3912e33aca2c5968f4a80c6a94b2189d020 Mon Sep 17 00:00:00 2001
From: better629
Date: Wed, 7 Feb 2024 15:56:01 +0800
Subject: [PATCH 05/15] add qianfan api support

---
 examples/llm_hello_world.py     |  21 +++--
 metagpt/configs/llm_config.py   |   8 +-
 metagpt/provider/__init__.py    |   2 +
 metagpt/provider/base_llm.py    |  16 ++++
 metagpt/provider/qianfan_api.py | 151 ++++++++++++++++++++++++++++++++
 metagpt/utils/cost_manager.py   |   4 +-
 metagpt/utils/token_counter.py  |  53 +++++++++++
 requirements.txt                |   1 +
 8 files changed, 245 insertions(+), 11 deletions(-)
 create mode 100644 metagpt/provider/qianfan_api.py

diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py
index 1d132eb8a..e22edbdf2 100644
--- a/examples/llm_hello_world.py
+++ b/examples/llm_hello_world.py
@@ -6,16 +6,25 @@
 @File : llm_hello_world.py
 """
 import asyncio
-from pathlib import Path
 
 from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.utils.common import encode_image
 
 
 async def main():
     llm = LLM()
-    logger.info(await llm.aask("hello world"))
+    # llm type check
+    id_ques = "what's your name"
+    logger.info(f"{id_ques}: ")
+    logger.info(await llm.aask(id_ques))
+    logger.info("\n\n")
+
+    logger.info(
+        await llm.aask(
+            "who are you", system_msgs=["act as a robot, answer 'I'm a robot' if the question is 'who are you'"]
+        )
+    )
+
     logger.info(await llm.aask_batch(["hi", "write python hello world."]))
 
     hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]
@@ -29,12 +38,6 @@ async def main():
     if hasattr(llm, "completion"):
         logger.info(llm.completion(hello_msg))
 
-    # check if the configured llm supports llm-vision capacity. If not, it will throw a error
-    invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
-    img_base64 = encode_image(invoice_path)
-    res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
-    assert "true" in res.lower()
-
 
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index fb923d3e4..1b05b5270 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -24,6 +24,7 @@ class LLMType(Enum):
     METAGPT = "metagpt"
     AZURE = "azure"
     OLLAMA = "ollama"
+    QIANFAN = "qianfan"  # Baidu BCE
 
     def __missing__(self, key):
         return self.OPENAI
@@ -36,13 +37,18 @@ class LLMConfig(YamlModel):
    Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
    """

-    api_key: str
+    api_key: str = "sk-"
     api_type: LLMType = LLMType.OPENAI
     base_url: str = "https://api.openai.com/v1"
     api_version: Optional[str] = None
 
     model: Optional[str] = None  # also stands for DEPLOYMENT_NAME
 
+    # For cloud service providers like Baidu / Alibaba
+    access_key: Optional[str] = None
+    secret_key: Optional[str] = None
+    endpoint: Optional[str] = None  # for self-deployed model on the cloud
+
     # For Spark(Xunfei), maybe remove later
     app_id: Optional[str] = None
     api_secret: Optional[str] = None
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index 675734811..8c0aab836 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -16,6 +16,7 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
 from metagpt.provider.metagpt_api import MetaGPTLLM
 from metagpt.provider.human_provider import HumanProvider
 from metagpt.provider.spark_api import SparkLLM
+from metagpt.provider.qianfan_api import QianFanLLM
 
 __all__ = [
     "FireworksLLM",
@@ -28,4 +29,5 @@ __all__ = [
     "OllamaLLM",
     "HumanProvider",
     "SparkLLM",
+    "QianFanLLM",
 ]
diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index b144471b5..d3d9c829b 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -67,6 +67,22 @@ class BaseLLM(ABC):
     def _default_system_msg(self):
         return self._system_msg(self.system_prompt)
 
+    def _update_costs(self, usage: dict, model: str = None, local_calc_usage: bool = True):
+        """update each request's token cost
+        Args:
+            model (str): model name or in some scenarios called endpoint
+            local_calc_usage (bool): some models don't calculate usage, it will overwrite calc_usage
+        """
+        calc_usage = self.config.calc_usage and local_calc_usage
+        model = model if model else self.model
+        if calc_usage and self.cost_manager:
+            try:
+                prompt_tokens = int(usage.get("prompt_tokens", 0))
+                completion_tokens = int(usage.get("completion_tokens", 0))
+                self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
+            except Exception as e:
+                logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
+
     async def aask(
         self,
         msg: str,
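The shared `BaseLLM._update_costs` added above assumes each provider has already set `self.config`, `self.model` and `self.cost_manager` before a request is made. A minimal sketch of the intended call pattern, not part of the patch — the provider name `MyLLM` and the transport call `_call_backend` are hypothetical:

    from metagpt.configs.llm_config import LLMConfig
    from metagpt.provider.base_llm import BaseLLM

    class MyLLM(BaseLLM):
        def __init__(self, config: LLMConfig):
            self.config = config
            self.model = config.model
            self.cost_manager = None  # usually injected by the runtime context

        async def _achat_completion(self, messages: list[dict]) -> dict:
            resp = await self._call_backend(messages)  # hypothetical transport call
            # the usage dict is expected to carry prompt_tokens/completion_tokens
            self._update_costs(resp.get("usage", {}))
            return resp

With `calc_usage` enabled in the config and a cost manager attached, every request then flows through `CostManager.update_cost`, which is exactly what the new QianFan provider below relies on.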
diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py
new file mode 100644
index 000000000..180935e61
--- /dev/null
+++ b/metagpt/provider/qianfan_api.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : llm api of qianfan from Baidu, supports ERNIE (wenxin yiyan) and open-source models
+import copy
+import os
+
+import qianfan
+from qianfan.resources.typing import JsonBody
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.logs import log_llm_stream, logger
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.provider.openai_api import log_and_reraise
+from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.token_counter import (
+    QianFan_EndPoint_TOKEN_COSTS,
+    QianFan_MODEL_TOKEN_COSTS,
+)
+
+
+@register_provider(LLMType.QIANFAN)
+class QianFanLLM(BaseLLM):
+    """
+    Refs
+     Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
+     Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
+     Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
+             https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
+    """
+
+    def __init__(self, config: LLMConfig):
+        self.config = config
+        self.use_system_prompt = False  # only some ERNIE-x related models support system_prompt
+        self.__init_qianfan()
+        self.cost_manager = CostManager(token_costs=self.token_costs)
+
+    def __init_qianfan(self):
+        if self.config.access_key and self.config.secret_key:
+            # system-level auth via access_key and secret_key, recommended by the official docs
+            # set environment variables, as the SDK officially recommends
+            os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
+            os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
+        elif self.config.api_key and self.config.secret_key:
+            # application-level auth via api_key and secret_key
+            # set environment variables, as the SDK officially recommends
+            os.environ.setdefault("QIANFAN_AK", self.config.api_key)
+            os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
+        else:
+            raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
+
+        support_system_pairs = [
+            ("ERNIE-Bot-4", "completions_pro"),  # (model, corresponding-endpoint)
+            ("ERNIE-Bot-8k", "ernie_bot_8k"),
+            ("ERNIE-Bot", "completions"),
+            ("ERNIE-Bot-turbo", "eb-instant"),
+            ("ERNIE-Speed", "ernie_speed"),
+            ("EB-turbo-AppBuilder", "ai_apaas"),
+        ]
+        if self.config.model in [pair[0] for pair in support_system_pairs]:
+            # only some ERNIE models support system prompt
+            self.use_system_prompt = True
+        if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
+            self.use_system_prompt = True
+
+        assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
+        assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
+
+        self.token_costs = copy.deepcopy(QianFan_MODEL_TOKEN_COSTS)
+        self.token_costs.update(QianFan_EndPoint_TOKEN_COSTS)
+
+        # self-deployed models on the cloud don't calculate usage; they charge a resource pool rental fee
+        self.calc_usage = self.config.calc_usage and self.config.endpoint is None
+        self.client = qianfan.ChatCompletion()
+
+    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
+        kwargs = {
+            "messages": messages,
+            "stream": stream,
+        }
+        if self.config.temperature > 0:
+            # different models have their own default temperature; only set it when explicitly specified
+            kwargs["temperature"] = self.config.temperature
+        if self.config.endpoint:
+            kwargs["endpoint"] = self.config.endpoint
+        elif self.config.model:
+            kwargs["model"] = self.config.model
+
+        if self.use_system_prompt:
+            # if the model supports system prompt, extract and pass it
+            if messages[0]["role"] == "system":
+                kwargs["messages"] = messages[1:]
+                kwargs["system"] = messages[0]["content"]  # set system prompt here
+        return kwargs
+
+    def _update_costs(self, usage: dict):
+        """update each request's token cost"""
+        model_or_endpoint = self.config.model if self.config.model else self.config.endpoint
+        local_calc_usage = model_or_endpoint in self.token_costs
+        super()._update_costs(usage, model_or_endpoint, local_calc_usage)
+
+    def get_choice_text(self, resp: JsonBody) -> str:
+        return resp.get("result", "")
+
+    def completion(self, messages: list[dict]) -> JsonBody:
+        resp = self.client.do(**self._const_kwargs(messages=messages, stream=False))
+        self._update_costs(resp.body.get("usage", {}))
+        return resp.body
+
+    async def _achat_completion(self, messages: list[dict]) -> JsonBody:
+        resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=False))
+        self._update_costs(resp.body.get("usage", {}))
+        return resp.body
+
+    async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
+        return await self._achat_completion(messages)
+
+    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+        resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=True))
+        collected_content = []
+        usage = {}
+        async for chunk in resp:
+            content = chunk.body.get("result", "")
+            usage = chunk.body.get("usage", {})
+            log_llm_stream(content)
+            collected_content.append(content)
+        log_llm_stream("\n")
+
+        self._update_costs(usage)
+        full_content = "".join(collected_content)
+        return full_content
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(min=1, max=60),
+        after=after_log(logger, logger.level("WARNING").name),
+        retry=retry_if_exception_type(ConnectionError),
+        retry_error_callback=log_and_reraise,
+    )
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
+        if stream:
+            return await self._achat_completion_stream(messages)
+        resp = await self._achat_completion(messages)
+        return self.get_choice_text(resp)
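A minimal sketch (not part of the patch) of driving the new provider end to end — the key values are placeholders, and `ERNIE-Bot-turbo` is one of the models priced in `QianFan_MODEL_TOKEN_COSTS` below:

    import asyncio

    from metagpt.configs.llm_config import LLMConfig
    from metagpt.provider.qianfan_api import QianFanLLM

    async def main():
        # system-level auth via access_key/secret_key, per the Baidu docs referenced above
        config = LLMConfig(
            api_type="qianfan",
            access_key="YOUR_ACCESS_KEY",
            secret_key="YOUR_SECRET_KEY",
            model="ERNIE-Bot-turbo",
        )
        llm = QianFanLLM(config)
        print(await llm.aask("hello"))

    asyncio.run(main())

Setting `endpoint` instead of `model` targets a self-deployed service; the asserts in `__init_qianfan` keep the two mutually exclusive.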
diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py
index 7bf5154b6..e1c0f415b 100644
--- a/metagpt/utils/cost_manager.py
+++ b/metagpt/utils/cost_manager.py
@@ -29,6 +29,7 @@ class CostManager(BaseModel):
     total_budget: float = 0
     max_budget: float = 10.0
     total_cost: float = 0
+    token_costs: dict[str, dict[str, float]] = TOKEN_COSTS
 
     def update_cost(self, prompt_tokens, completion_tokens, model):
         """
@@ -42,7 +43,8 @@
         self.total_prompt_tokens += prompt_tokens
         self.total_completion_tokens += completion_tokens
         cost = (
-            prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
+            prompt_tokens * self.token_costs[model]["prompt"]
+            + completion_tokens * self.token_costs[model]["completion"]
         ) / 1000
         self.total_cost += cost
         logger.info(
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index a0fb3b70d..b69ec73d3 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -38,6 +38,59 @@ TOKEN_COSTS = {
 }
 
 
+"""
+QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
+Since QianFan has multiple pricing strategies, we uniformly use `Tokens post-payment` as the statistical method.
+"""
+QianFan_MODEL_TOKEN_COSTS = {
+    "ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
+    "ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
+    "ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
+    "ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
+    "EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
+    "ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
+    "BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
+    "Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
+    "Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
+    "Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
+    "ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
+    "AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
+    "Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
+    "SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
+    "CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
+    "XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
+    "Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
+    "Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
+    "Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
+    "ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
+    "Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
+}
+
+QianFan_EndPoint_TOKEN_COSTS = {
+    "completions_pro": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
+    "ernie_bot_8k": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
+    "completions": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot"],
+    "eb-instant": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
+    "ai_apaas": QianFan_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
+    "ernie_speed": QianFan_MODEL_TOKEN_COSTS["ERNIE-Speed"],
+    "bloomz_7b1": QianFan_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
+    "llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
+    "llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
+    "llama_2_70b": QianFan_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
+    "chatglm2_6b_32k": QianFan_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
+    "aquilachat_7b": QianFan_MODEL_TOKEN_COSTS["AquilaChat-7B"],
+    "mixtral_8x7b_instruct": QianFan_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
+    "sqlcoder_7b": QianFan_MODEL_TOKEN_COSTS["SQLCoder-7B"],
+    "codellama_7b_instruct": QianFan_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
+    "xuanyuan_70b_chat": QianFan_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
+    "qianfan_bloomz_7b_compressed": QianFan_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
+    "qianfan_chinese_llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
+    "qianfan_chinese_llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
+    "chatlaw": QianFan_MODEL_TOKEN_COSTS["ChatLaw"],
+    "yi_34b_chat": QianFan_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
+}
+
+
 TOKEN_MAX = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0301": 4096,
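Since `CostManager` now carries its own price table (the cost_manager.py hunk above), the QianFan tables plug in directly; values are USD per 1k tokens. A small sketch (not part of the patch) of the resulting arithmetic:

    from metagpt.utils.cost_manager import CostManager
    from metagpt.utils.token_counter import QianFan_MODEL_TOKEN_COSTS

    cm = CostManager(token_costs=QianFan_MODEL_TOKEN_COSTS)
    # ERNIE-Bot-turbo is priced at 0.0011 (prompt) / 0.0011 (completion) per 1k tokens
    cm.update_cost(prompt_tokens=1000, completion_tokens=500, model="ERNIE-Bot-turbo")
    print(cm.total_cost)  # (1000 * 0.0011 + 500 * 0.0011) / 1000 = 0.00165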
diff --git a/requirements.txt b/requirements.txt
index 6cb25d52b..c893bd713 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -67,3 +67,4 @@ google-generativeai==0.3.2
 anytree
 ipywidgets==8.1.1
 Pillow
+qianfan==0.3.1

From 15a9c5e94135992e9854a57d14d581040879386f Mon Sep 17 00:00:00 2001
From: better629
Date: Wed, 7 Feb 2024 16:16:14 +0800
Subject: [PATCH 06/15] simplify _update_costs and related code

---
 metagpt/provider/base_llm.py          | 13 ++++++++++---
 metagpt/provider/fireworks_api.py     | 15 ++-------------
 metagpt/provider/google_gemini_api.py | 10 ----------
 metagpt/provider/ollama_api.py        | 10 ----------
 metagpt/provider/open_llm_api.py      | 13 +------------
 metagpt/provider/openai_api.py        | 17 ++---------------
 metagpt/provider/qianfan_api.py       |  8 ++++----
 metagpt/provider/zhipuai_api.py       | 10 ----------
 8 files changed, 19 insertions(+), 77 deletions(-)

diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index d3d9c829b..2f57b15aa 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
 from typing import Optional, Union
 
 from openai import AsyncOpenAI
+from pydantic import BaseModel
 
 from metagpt.configs.llm_config import LLMConfig
 from metagpt.logs import logger
 from metagpt.schema import Message
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.cost_manager import CostManager, Costs
 
 
 class BaseLLM(ABC):
@@ -67,14 +68,15 @@ class BaseLLM(ABC):
     def _default_system_msg(self):
         return self._system_msg(self.system_prompt)
 
-    def _update_costs(self, usage: dict, model: str = None, local_calc_usage: bool = True):
+    def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
         """update each request's token cost
         Args:
             model (str): model name or in some scenarios called endpoint
-            local_calc_usage (bool): some models don't calculate usage, it will overwrite calc_usage
+            local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
         """
         calc_usage = self.config.calc_usage and local_calc_usage
         model = model if model else self.model
+        usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
         if calc_usage and self.cost_manager:
             try:
                 prompt_tokens = int(usage.get("prompt_tokens", 0))
@@ -83,6 +85,11 @@ class BaseLLM(ABC):
             except Exception as e:
                 logger.error(f"{self.__class__.__name__} updates costs failed! 
exp: {e}") + def get_costs(self) -> Costs: + if not self.cost_manager: + return Costs(0, 0, 0, 0) + return self.cost_manager.get_costs() + async def aask( self, msg: str, diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index d56453a85..e62a7066e 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM, log_and_reraise -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import CostManager MODEL_GRADE_TOKEN_COSTS = { "-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition @@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM): kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url) return kwargs - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage: - try: - # use FireworksCostManager not context.cost_manager - self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - except Exception as e: - logger.error(f"updating costs failed!, exp: {e}") - - def get_costs(self) -> Costs: - return self.cost_manager.get_costs() - async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages), stream=True @@ -113,7 +102,7 @@ class FireworksLLM(OpenAILLM): usage = CompletionUsage(**chunk.usage) full_content = "".join(collected_content) - self._update_costs(usage) + self._update_costs(usage.model_dump()) return full_content @retry( diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 2647ab16b..87ea81c80 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM): kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"google gemini updats costs failed! exp: {e}") - def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index c9103b018..52e8dbe36 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"ollama updats costs failed! 
exp: {e}")
-
     def get_choice_text(self, resp: dict) -> str:
         """get the resp content from llm response"""
         assist_msg = resp.get("message", {})
diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py
index a29b263a4..69371e379 100644
--- a/metagpt/provider/open_llm_api.py
+++ b/metagpt/provider/open_llm_api.py
@@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.logs import logger
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.provider.openai_api import OpenAILLM
-from metagpt.utils.cost_manager import Costs, TokenCostManager
+from metagpt.utils.cost_manager import TokenCostManager
 from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
 
 
@@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
             logger.error(f"usage calculation failed!: {e}")
 
         return usage
-
-    def _update_costs(self, usage: CompletionUsage):
-        if self.config.calc_usage and usage:
-            try:
-                # use OpenLLMCostManager not CONFIG.cost_manager
-                self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"updating costs failed!, exp: {e}")
-
-    def get_costs(self) -> Costs:
-        return self._cost_manager.get_costs()
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 63e68c9bd..1e5770d74 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -29,7 +29,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.schema import Message
 from metagpt.utils.common import CodeParser, decode_image
-from metagpt.utils.cost_manager import CostManager, Costs
+from metagpt.utils.cost_manager import CostManager
 from metagpt.utils.exceptions import handle_exception
 from metagpt.utils.token_counter import (
     count_message_tokens,
@@ -55,16 +55,13 @@ class OpenAILLM(BaseLLM):
 
     def __init__(self, config: LLMConfig):
         self.config = config
-        self._init_model()
         self._init_client()
         self.auto_max_tokens = False
         self.cost_manager: Optional[CostManager] = None
 
-    def _init_model(self):
-        self.model = self.config.model  # Used in _calc_usage & _cons_kwargs
-
     def _init_client(self):
         """https://github.com/openai/openai-python#async-usage"""
+        self.model = self.config.model  # Used in _calc_usage & _cons_kwargs
         kwargs = self._make_client_kwargs()
         self.aclient = AsyncOpenAI(**kwargs)
 
@@ -240,16 +237,6 @@ class OpenAILLM(BaseLLM):
 
         return usage
 
-    @handle_exception
-    def _update_costs(self, usage: CompletionUsage):
-        if self.config.calc_usage and usage and self.cost_manager:
-            self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-
-    def get_costs(self) -> Costs:
-        if not self.cost_manager:
-            return Costs(0, 0, 0, 0)
-        return self.cost_manager.get_costs()
-
     def _get_max_tokens(self, messages: list[dict]):
         if not self.auto_max_tokens:
             return self.config.max_token
diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py
index 180935e61..fbbff7085 100644
--- a/metagpt/provider/qianfan_api.py
+++ b/metagpt/provider/qianfan_api.py
@@ -78,7 +78,7 @@ class QianFanLLM(BaseLLM):
 
         # self-deployed models on the cloud don't calculate usage; they charge a resource pool rental fee
         self.calc_usage = self.config.calc_usage and self.config.endpoint is None
-        self.client = qianfan.ChatCompletion()
+        self.aclient = qianfan.ChatCompletion()
 
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
         kwargs = {
@@ -110,12 +110,12 @@ class QianFanLLM(BaseLLM): return resp.get("result", "") def completion(self, messages: list[dict]) -> JsonBody: - resp = self.client.do(**self._const_kwargs(messages=messages, stream=False)) + resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False)) self._update_costs(resp.body.get("usage", {})) return resp.body async def _achat_completion(self, messages: list[dict]) -> JsonBody: - resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=False)) + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False)) self._update_costs(resp.body.get("usage", {})) return resp.body @@ -123,7 +123,7 @@ class QianFanLLM(BaseLLM): return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: - resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=True)) + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True)) collected_content = [] usage = {} async for chunk in resp: diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 9108a1fba..b7c160a41 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -57,16 +57,6 @@ class ZhiPuAILLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"zhipuai updats costs failed! exp: {e}") - def completion(self, messages: list[dict], timeout=3) -> dict: resp = self.llm.chat.completions.create(**self._const_kwargs(messages)) usage = resp.usage.model_dump() From 4370060802b3da936880aefb7aa28a6ba22780cd Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 7 Feb 2024 16:23:54 +0800 Subject: [PATCH 07/15] fix bug --- config/config2.yaml.example | 2 +- metagpt/actions/research.py | 2 +- metagpt/utils/cost_manager.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/config/config2.yaml.example b/config/config2.yaml.example index 8f4a33fc1..2217f1b2c 100644 --- a/config/config2.yaml.example +++ b/config/config2.yaml.example @@ -1,5 +1,5 @@ llm: - api_type: "openai" + api_type: "openai" # or azure / ollama etc. 
base_url: "YOUR_BASE_URL" api_key: "YOUR_API_KEY" model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2ebeadb66..316e9f299 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -133,7 +133,7 @@ class CollectLinks(Action): if len(remove) == 0: break - model_name = config.get_openai_llm().model + model_name = config.model prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096) logger.debug(prompt) queries = await self._aask(prompt, [system_text]) diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py index 7bf5154b6..c4c93f91f 100644 --- a/metagpt/utils/cost_manager.py +++ b/metagpt/utils/cost_manager.py @@ -41,6 +41,10 @@ class CostManager(BaseModel): """ self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens + if model not in TOKEN_COSTS: + logger.warning(f"Model {model} not found in TOKEN_COSTS.") + return + cost = ( prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"] ) / 1000 From c0867643d828084e7503f05ae44987dccf3687d1 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 7 Feb 2024 16:24:33 +0800 Subject: [PATCH 08/15] fix bug --- metagpt/actions/research.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 316e9f299..ce8d8a967 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -133,7 +133,7 @@ class CollectLinks(Action): if len(remove) == 0: break - model_name = config.model + model_name = config.llm.model prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096) logger.debug(prompt) queries = await self._aask(prompt, [system_text]) From d112371dadf02ee9a828c6708d0bbaa3e600c113 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 7 Feb 2024 16:33:24 +0800 Subject: [PATCH 09/15] fix bug --- metagpt/utils/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/utils/text.py b/metagpt/utils/text.py index dd9678438..921efe706 100644 --- a/metagpt/utils/text.py +++ b/metagpt/utils/text.py @@ -25,7 +25,7 @@ def reduce_message_length( """ max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved for msg in msgs: - if count_string_tokens(msg, model_name) < max_token: + if count_string_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX: return msg raise RuntimeError("fail to reduce message length") From 50a14718baeacfada5cf7008e2761a801adbd968 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 7 Feb 2024 16:37:23 +0800 Subject: [PATCH 10/15] refine log --- metagpt/provider/openai_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 63e68c9bd..120748d15 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -236,7 +236,7 @@ class OpenAILLM(BaseLLM): usage.prompt_tokens = count_message_tokens(messages, self.model) usage.completion_tokens = count_string_tokens(rsp, self.model) except Exception as e: - logger.error(f"usage calculation failed: {e}") + logger.warning(f"usage calculation failed: {e}") return usage From ce63e455dfe1071a99ee421c1e17df07db20200d Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 7 Feb 2024 17:03:10 +0800 Subject: [PATCH 11/15] fix bug --- metagpt/provider/openai_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 120748d15..756f8c483 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -253,7 +253,7 @@ class OpenAILLM(BaseLLM): def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return self.config.max_token - return get_max_completion_tokens(messages, self.model, self.config.max_tokens) + return get_max_completion_tokens(messages, self.model, self.config.max_token) @handle_exception async def amoderation(self, content: Union[str, list[str]]): From dc240a2efd161614f2e4b5090238f72682158ae5 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 7 Feb 2024 17:40:27 +0800 Subject: [PATCH 12/15] simplify provider ut code --- .github/workflows/fulltest.yaml | 1 - .github/workflows/unittest.yaml | 2 +- tests/metagpt/provider/mock_llm_config.py | 10 +++ tests/metagpt/provider/req_resp_const.py | 80 +++++++++++++++++++ tests/metagpt/provider/test_anthropic_api.py | 12 +-- tests/metagpt/provider/test_base_llm.py | 53 +++++------- tests/metagpt/provider/test_fireworks_llm.py | 65 ++++----------- .../provider/test_google_gemini_api.py | 37 +++++---- tests/metagpt/provider/test_ollama_api.py | 20 +++-- tests/metagpt/provider/test_open_llm_api.py | 65 +++++---------- tests/metagpt/provider/test_qianfan_api.py | 15 ++++ tests/metagpt/provider/test_spark_api.py | 36 +++++---- tests/metagpt/provider/test_zhipuai_api.py | 33 ++++---- tests/spark.yaml | 7 -- 14 files changed, 235 insertions(+), 201 deletions(-) create mode 100644 tests/metagpt/provider/req_resp_const.py create mode 100644 tests/metagpt/provider/test_qianfan_api.py delete mode 100644 tests/spark.yaml diff --git a/.github/workflows/fulltest.yaml b/.github/workflows/fulltest.yaml index f5c6049e1..70c800481 100644 --- a/.github/workflows/fulltest.yaml +++ b/.github/workflows/fulltest.yaml @@ -54,7 +54,6 @@ jobs: export ALLOW_OPENAI_API_CALL=0 echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml - echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 2e7e3ce2b..afa9faba7 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -31,7 +31,7 @@ jobs: - name: Test with pytest run: | export ALLOW_OPENAI_API_CALL=0 - mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml + mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index e2f626a6a..21780f914 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -42,3 +42,13 @@ mock_llm_config_zhipu = LLMConfig( model="mock_zhipu_model", proxy="http://localhost:8080", ) + + +mock_llm_config_spark = LLMConfig( + api_type="spark", + app_id="xxx", + api_key="xxx", + api_secret="xxx", + domain="generalv2", + base_url="wss://spark-api.xf-yun.com/v3.1/chat", +) diff --git 
a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py new file mode 100644 index 000000000..a3a7a363c --- /dev/null +++ b/tests/metagpt/provider/req_resp_const.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : default request & response data for provider unittest + +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from openai.types.completion_usage import CompletionUsage + +prompt = "who are you?" +messages = [{"role": "user", "content": prompt}] + +resp_cont_tmpl = "I'm {name}" +default_resp_cont = resp_cont_tmpl.format(name="GPT") + + +# part of whole ChatCompletion of openai like structure +def get_part_chat_completion(llm_name: str) -> dict: + part_chat_completion = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": resp_cont_tmpl.format(name=llm_name), + }, + "finish_reason": "stop", + } + ], + "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, + } + return part_chat_completion + + +def get_openai_chat_completion(llm_name: str) -> ChatCompletion: + openai_chat_completion = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="xx/xxx", + object="chat.completion", + created=1703300855, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=llm_name)), + logprobs=None, + ) + ], + usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), + ) + return openai_chat_completion + + +def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: + usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) + usage = usage if not usage_as_dict else usage.model_dump() + openai_chat_completion_chunk = ChatCompletionChunk( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="xx/xxx", + object="chat.completion.chunk", + created=1703300855, + choices=[ + AChoice( + delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=llm_name)), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + usage=usage, + ) + return openai_chat_completion_chunk + + +gemini_messages = [{"role": "user", "parts": prompt}] diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py index 6962ab064..93cfd7dbc 100644 --- a/tests/metagpt/provider/test_anthropic_api.py +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion from metagpt.provider.anthropic_api import Claude2 from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl -prompt = "who are you" -resp = "I'am Claude2" +resp_cont = resp_cont_tmpl.format(name="Claude") def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: - return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") + return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion") async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) 
-> Completion: - return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") + return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion") def test_claude2_ask(mocker): mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create) - assert resp == Claude2(mock_llm_config).ask(prompt) + assert resp_cont == Claude2(mock_llm_config).ask(prompt) @pytest.mark.asyncio async def test_claude2_aask(mocker): mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create) - assert resp == await Claude2(mock_llm_config).aask(prompt) + assert resp_cont == await Claude2(mock_llm_config).aask(prompt) diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index cc781f78a..0babd6d5f 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -11,21 +11,13 @@ import pytest from metagpt.configs.llm_config import LLMConfig from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message +from tests.metagpt.provider.req_resp_const import ( + default_resp_cont, + get_part_chat_completion, + prompt, +) -default_chat_resp = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "I'am GPT", - }, - "finish_reason": "stop", - } - ] -} -prompt_msg = "who are you" -resp_content = default_chat_resp["choices"][0]["message"]["content"] +llm_name = "GPT" class MockBaseLLM(BaseLLM): @@ -33,16 +25,13 @@ class MockBaseLLM(BaseLLM): pass def completion(self, messages: list[dict], timeout=3): - return default_chat_resp + return get_part_chat_completion(llm_name) async def acompletion(self, messages: list[dict], timeout=3): - return default_chat_resp + return get_part_chat_completion(llm_name) async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: - return resp_content - - async def close(self): - return default_chat_resp + return default_resp_cont def test_base_llm(): @@ -86,25 +75,25 @@ def test_base_llm(): choice_text = base_llm.get_choice_text(openai_funccall_resp) assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"] - # resp = base_llm.ask(prompt_msg) - # assert resp == resp_content + # resp = base_llm.ask(prompt) + # assert resp == default_resp_cont - # resp = base_llm.ask_batch([prompt_msg]) - # assert resp == resp_content + # resp = base_llm.ask_batch([prompt]) + # assert resp == default_resp_cont - # resp = base_llm.ask_code([prompt_msg]) - # assert resp == resp_content + # resp = base_llm.ask_code([prompt]) + # assert resp == default_resp_cont @pytest.mark.asyncio async def test_async_base_llm(): base_llm = MockBaseLLM() - resp = await base_llm.aask(prompt_msg) - assert resp == resp_content + resp = await base_llm.aask(prompt) + assert resp == default_resp_cont - resp = await base_llm.aask_batch([prompt_msg]) - assert resp == resp_content + resp = await base_llm.aask_batch([prompt]) + assert resp == default_resp_cont - # resp = await base_llm.aask_code([prompt_msg]) - # assert resp == resp_content + # resp = await base_llm.aask_code([prompt]) + # assert resp == default_resp_cont diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py index 66b55e5b2..834f6305f 100644 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ b/tests/metagpt/provider/test_fireworks_llm.py @@ -3,14 +3,7 @@ # @Desc : the 
unittest of fireworks api import pytest -from openai.types.chat.chat_completion import ( - ChatCompletion, - ChatCompletionMessage, - Choice, -) from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.chat.chat_completion_chunk import Choice as AChoice -from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.provider.fireworks_api import ( @@ -20,42 +13,18 @@ from metagpt.provider.fireworks_api import ( ) from metagpt.utils.cost_manager import Costs from tests.metagpt.provider.mock_llm_config import mock_llm_config - -resp_content = "I'm fireworks" -default_resp = ChatCompletion( - id="cmpl-a6652c1bb181caae8dd19ad8", - model="accounts/fireworks/models/llama-v2-13b-chat", - object="chat.completion", - created=1703300855, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(role="assistant", content=resp_content), - logprobs=None, - ) - ], - usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), +from tests.metagpt.provider.req_resp_const import ( + get_openai_chat_completion, + get_openai_chat_completion_chunk, + messages, + prompt, + resp_cont_tmpl, ) -default_resp_chunk = ChatCompletionChunk( - id=default_resp.id, - model=default_resp.model, - object="chat.completion.chunk", - created=default_resp.created, - choices=[ - AChoice( - delta=ChoiceDelta(content=resp_content, role="assistant"), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], - usage=dict(default_resp.usage), -) - -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] +llm_name = "fireworks" +resp_cont = resp_cont_tmpl.format(name=llm_name) +default_resp = get_openai_chat_completion(llm_name) +default_resp_chunk = get_openai_chat_completion_chunk(llm_name, usage_as_dict=True) def test_fireworks_costmanager(): @@ -99,16 +68,16 @@ async def test_fireworks_acompletion(mocker): ) resp = await fireworks_gpt.acompletion(messages) - assert resp.choices[0].message.content in resp_content + assert resp.choices[0].message.content in resp_cont - resp = await fireworks_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await fireworks_gpt.aask(prompt, stream=False) + assert resp == resp_cont resp = await fireworks_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content + assert resp == resp_cont resp = await fireworks_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content + assert resp == resp_cont - resp = await fireworks_gpt.aask(prompt_msg) - assert resp == resp_content + resp = await fireworks_gpt.aask(prompt) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 404ae1e90..ad0c7bbfe 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -11,6 +11,11 @@ from google.generativeai.types import content_types from metagpt.provider.google_gemini_api import GeminiLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.req_resp_const import ( + gemini_messages, + prompt, + resp_cont_tmpl, +) @dataclass @@ -18,10 +23,8 @@ class MockGeminiResponse(ABC): text: str -prompt_msg = "who are you" -messages = [{"role": "user", "parts": prompt_msg}] -resp_content = "I'm gemini from google" -default_resp = MockGeminiResponse(text=resp_content) +resp_cont = 
resp_cont_tmpl.format(name="gemini") +default_resp = MockGeminiResponse(text=resp_cont) def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: @@ -62,26 +65,26 @@ async def test_gemini_acompletion(mocker): gemini_gpt = GeminiLLM(mock_llm_config) - assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} - assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} + assert gemini_gpt._user_msg(prompt) == {"role": "user", "parts": [prompt]} + assert gemini_gpt._assistant_msg(prompt) == {"role": "model", "parts": [prompt]} - usage = gemini_gpt.get_usage(messages, resp_content) + usage = gemini_gpt.get_usage(gemini_messages, resp_cont) assert usage == {"prompt_tokens": 20, "completion_tokens": 20} - resp = gemini_gpt.completion(messages) + resp = gemini_gpt.completion(gemini_messages) assert resp == default_resp - resp = await gemini_gpt.acompletion(messages) + resp = await gemini_gpt.acompletion(gemini_messages) assert resp.text == default_resp.text - resp = await gemini_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await gemini_gpt.aask(prompt, stream=False) + assert resp == resp_cont - resp = await gemini_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content + resp = await gemini_gpt.acompletion_text(gemini_messages, stream=False) + assert resp == resp_cont - resp = await gemini_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content + resp = await gemini_gpt.acompletion_text(gemini_messages, stream=True) + assert resp == resp_cont - resp = await gemini_gpt.aask(prompt_msg) - assert resp == resp_content + resp = await gemini_gpt.aask(prompt) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 5d942598b..8e2625e35 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -9,12 +9,10 @@ import pytest from metagpt.provider.ollama_api import OllamaLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.req_resp_const import messages, prompt, resp_cont_tmpl -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] - -resp_content = "I'm ollama" -default_resp = {"message": {"role": "assistant", "content": resp_content}} +resp_cont = resp_cont_tmpl.format(name="ollama") +default_resp = {"message": {"role": "assistant", "content": resp_cont}} async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: @@ -46,14 +44,14 @@ async def test_gemini_acompletion(mocker): resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] - resp = await ollama_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await ollama_gpt.aask(prompt, stream=False) + assert resp == resp_cont resp = await ollama_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content + assert resp == resp_cont resp = await ollama_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content + assert resp == resp_cont - resp = await ollama_gpt.aask(prompt_msg) - assert resp == resp_content + resp = await ollama_gpt.aask(prompt) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index fc7b510cc..5b8a506e9 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ 
b/tests/metagpt/provider/test_open_llm_api.py @@ -3,53 +3,25 @@ # @Desc : import pytest -from openai.types.chat.chat_completion import ( - ChatCompletion, - ChatCompletionMessage, - Choice, -) from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.chat.chat_completion_chunk import Choice as AChoice -from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.provider.open_llm_api import OpenLLM -from metagpt.utils.cost_manager import Costs +from metagpt.utils.cost_manager import CostManager, Costs from tests.metagpt.provider.mock_llm_config import mock_llm_config - -resp_content = "I'm llama2" -default_resp = ChatCompletion( - id="cmpl-a6652c1bb181caae8dd19ad8", - model="llama-v2-13b-chat", - object="chat.completion", - created=1703302755, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(role="assistant", content=resp_content), - logprobs=None, - ) - ], +from tests.metagpt.provider.req_resp_const import ( + get_openai_chat_completion, + get_openai_chat_completion_chunk, + messages, + prompt, + resp_cont_tmpl, ) -default_resp_chunk = ChatCompletionChunk( - id=default_resp.id, - model=default_resp.model, - object="chat.completion.chunk", - created=default_resp.created, - choices=[ - AChoice( - delta=ChoiceDelta(content=resp_content, role="assistant"), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], -) +llm_name = "llama2-7b" +resp_cont = resp_cont_tmpl.format(name=llm_name) +default_resp = get_openai_chat_completion(llm_name) -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] +default_resp_chunk = get_openai_chat_completion_chunk(llm_name) async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: @@ -71,22 +43,23 @@ async def test_openllm_acompletion(mocker): openllm_gpt = OpenLLM(mock_llm_config) openllm_gpt.model = "llama-v2-13b-chat" + openllm_gpt.cost_manager = CostManager() openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) assert openllm_gpt.get_costs() == Costs( total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 ) resp = await openllm_gpt.acompletion(messages) - assert resp.choices[0].message.content in resp_content + assert resp.choices[0].message.content in resp_cont - resp = await openllm_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await openllm_gpt.aask(prompt, stream=False) + assert resp == resp_cont resp = await openllm_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content + assert resp == resp_cont resp = await openllm_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content + assert resp == resp_cont - resp = await openllm_gpt.aask(prompt_msg) - assert resp == resp_content + resp = await openllm_gpt.aask(prompt) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py new file mode 100644 index 000000000..76271b1e8 --- /dev/null +++ b/tests/metagpt/provider/test_qianfan_api.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of qianfan api + +import pytest + +from metagpt.provider.qianfan_api import QianFanLLM +from tests.metagpt.provider.req_resp_const import prompt, messages, resp_cont_tmpl + + +resp_cont = resp_cont_tmpl.format(name="ERNIE-Bot-turbo") + + +def 
test_qianfan_acompletion(mocker): + assert True, True diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index f5a6f66fd..32a839393 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,12 +4,14 @@ import pytest -from metagpt.config2 import Config from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM -from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.mock_llm_config import ( + mock_llm_config, + mock_llm_config_spark, +) +from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl -prompt_msg = "who are you" -resp_content = "I'm Spark" +resp_cont = resp_cont_tmpl.format(name="Spark") class MockWebSocketApp(object): @@ -23,7 +25,7 @@ class MockWebSocketApp(object): def test_get_msg_from_web(mocker): mocker.patch("websocket.WebSocketApp", MockWebSocketApp) - get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config) + get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config) assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain" ret = get_msg_from_web.run() @@ -31,15 +33,17 @@ def test_get_msg_from_web(mocker): def mock_spark_get_msg_from_web_run(self) -> str: - return resp_content + return resp_cont @pytest.mark.asyncio -async def test_spark_aask(): - llm = SparkLLM(Config.from_home("spark.yaml").llm) +async def test_spark_aask(mocker): + mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) + + llm = SparkLLM(mock_llm_config_spark) resp = await llm.aask("Hello!") - print(resp) + assert resp == resp_cont @pytest.mark.asyncio @@ -49,16 +53,16 @@ async def test_spark_acompletion(mocker): spark_gpt = SparkLLM(mock_llm_config) resp = await spark_gpt.acompletion([]) - assert resp == resp_content + assert resp == resp_cont - resp = await spark_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await spark_gpt.aask(prompt, stream=False) + assert resp == resp_cont resp = await spark_gpt.acompletion_text([], stream=False) - assert resp == resp_content + assert resp == resp_cont resp = await spark_gpt.acompletion_text([], stream=True) - assert resp == resp_content + assert resp == resp_cont - resp = await spark_gpt.aask(prompt_msg) - assert resp == resp_content + resp = await spark_gpt.aask(prompt) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 798209710..064562bff 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -6,22 +6,23 @@ import pytest from metagpt.provider.zhipuai_api import ZhiPuAILLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu +from tests.metagpt.provider.req_resp_const import ( + get_part_chat_completion, + messages, + prompt, + resp_cont_tmpl, +) -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] - -resp_content = "I'm chatglm-turbo" -default_resp = { - "choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}], - "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, -} +llm_name = "ChatGLM-4" +resp_cont = resp_cont_tmpl.format(name=llm_name) +default_resp = get_part_chat_completion(llm_name) async def mock_zhipuai_acreate_stream(**kwargs): class MockResponse(object): async def _aread(self): class Iterator(object): - events = [{"choices": 
[{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}] + events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}] async def __aiter__(self): for event in self.events: @@ -49,19 +50,19 @@ async def test_zhipuai_acompletion(mocker): zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu) resp = await zhipu_gpt.acompletion(messages) - assert resp["choices"][0]["message"]["content"] == resp_content + assert resp["choices"][0]["message"]["content"] == resp_cont - resp = await zhipu_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await zhipu_gpt.aask(prompt, stream=False) + assert resp == resp_cont resp = await zhipu_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content + assert resp == resp_cont resp = await zhipu_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content + assert resp == resp_cont - resp = await zhipu_gpt.aask(prompt_msg) - assert resp == resp_content + resp = await zhipu_gpt.aask(prompt) + assert resp == resp_cont def test_zhipuai_proxy(): diff --git a/tests/spark.yaml b/tests/spark.yaml deleted file mode 100644 index a5bbd98bd..000000000 --- a/tests/spark.yaml +++ /dev/null @@ -1,7 +0,0 @@ -llm: - api_type: "spark" - app_id: "xxx" - api_key: "xxx" - api_secret: "xxx" - domain: "generalv2" - base_url: "wss://spark-api.xf-yun.com/v3.1/chat" \ No newline at end of file From d94f4fbfbc3bd4310669f06e9bac9a7c89001712 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Wed, 7 Feb 2024 17:44:36 +0800 Subject: [PATCH 13/15] fix research bugs --- metagpt/provider/openai_api.py | 4 +++- metagpt/utils/text.py | 2 +- tests/metagpt/utils/test_text.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 63e68c9bd..7b2cd6220 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -253,7 +253,9 @@ class OpenAILLM(BaseLLM): def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return self.config.max_token - return get_max_completion_tokens(messages, self.model, self.config.max_tokens) + # FIXME + # https://community.openai.com/t/why-is-gpt-3-5-turbo-1106-max-tokens-limited-to-4096/494973/3 + return min(get_max_completion_tokens(messages, self.model, self.config.max_tokens), 4096) @handle_exception async def amoderation(self, content: Union[str, list[str]]): diff --git a/metagpt/utils/text.py b/metagpt/utils/text.py index 921efe706..fb8b94232 100644 --- a/metagpt/utils/text.py +++ b/metagpt/utils/text.py @@ -93,7 +93,7 @@ def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str continue ret = ["".join(j) for j in _split_by_count(sentences, count)] return ret - return _split_by_count(paragraph, count) + return list(_split_by_count(paragraph, count)) def decode_unicode_escape(text: str) -> str: diff --git a/tests/metagpt/utils/test_text.py b/tests/metagpt/utils/test_text.py index 7003c7767..c9a9753be 100644 --- a/tests/metagpt/utils/test_text.py +++ b/tests/metagpt/utils/test_text.py @@ -42,6 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1), (" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2), (" ".join("Hello World." 
for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1), + (" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1000, 8), ], ) def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected): From d3f6e38e8a9805d6bc80e5489dde99007bd20b6d Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 7 Feb 2024 18:32:32 +0800 Subject: [PATCH 14/15] add qianfan ut code and update xx_llm from xx_gpt --- metagpt/provider/qianfan_api.py | 3 +- tests/metagpt/provider/mock_llm_config.py | 7 +++ tests/metagpt/provider/req_resp_const.py | 57 ++++++++++++++++--- tests/metagpt/provider/test_base_llm.py | 6 +- tests/metagpt/provider/test_fireworks_llm.py | 26 ++++----- .../provider/test_google_gemini_api.py | 20 +++---- tests/metagpt/provider/test_ollama_api.py | 12 ++-- tests/metagpt/provider/test_open_llm_api.py | 28 ++++----- tests/metagpt/provider/test_qianfan_api.py | 39 +++++++++++-- tests/metagpt/provider/test_spark_api.py | 12 ++-- tests/metagpt/provider/test_zhipuai_api.py | 23 +++----- 11 files changed, 153 insertions(+), 80 deletions(-) diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index fbbff7085..6f94b9cea 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -5,6 +5,7 @@ import copy import os import qianfan +from qianfan import ChatCompletion from qianfan.resources.typing import JsonBody from tenacity import ( after_log, @@ -78,7 +79,7 @@ class QianFanLLM(BaseLLM): # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee self.calc_usage = self.config.calc_usage and self.config.endpoint is None - self.aclient = qianfan.ChatCompletion() + self.aclient: ChatCompletion = qianfan.ChatCompletion() def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = { diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 21780f914..e61e32e8b 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -52,3 +52,10 @@ mock_llm_config_spark = LLMConfig( domain="generalv2", base_url="wss://spark-api.xf-yun.com/v3.1/chat", ) + +mock_llm_config_qianfan = LLMConfig( + api_type="qianfan", + access_key="xxx", + secret_key="xxx", + model="ERNIE-Bot-turbo" +) diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index a3a7a363c..20d8e0914 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : default request & response data for provider unittest +from typing import Dict from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, @@ -11,6 +12,9 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as AChoice from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage +from qianfan.resources.typing import QfResponse, default_field + +from metagpt.provider.base_llm import BaseLLM prompt = "who are you?" 
messages = [{"role": "user", "content": prompt}] @@ -20,14 +24,14 @@ default_resp_cont = resp_cont_tmpl.format(name="GPT") # part of whole ChatCompletion of openai like structure -def get_part_chat_completion(llm_name: str) -> dict: +def get_part_chat_completion(name: str) -> dict: part_chat_completion = { "choices": [ { "index": 0, "message": { "role": "assistant", - "content": resp_cont_tmpl.format(name=llm_name), + "content": resp_cont_tmpl.format(name=name), }, "finish_reason": "stop", } @@ -37,7 +41,7 @@ def get_part_chat_completion(llm_name: str) -> dict: return part_chat_completion -def get_openai_chat_completion(llm_name: str) -> ChatCompletion: +def get_openai_chat_completion(name: str) -> ChatCompletion: openai_chat_completion = ChatCompletion( id="cmpl-a6652c1bb181caae8dd19ad8", model="xx/xxx", @@ -47,7 +51,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=llm_name)), + message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), logprobs=None, ) ], @@ -56,7 +60,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion: return openai_chat_completion -def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: +def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) usage = usage if not usage_as_dict else usage.model_dump() openai_chat_completion_chunk = ChatCompletionChunk( @@ -66,7 +70,7 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) created=1703300855, choices=[ AChoice( - delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=llm_name)), + delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), finish_reason="stop", index=0, logprobs=None, @@ -76,5 +80,44 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) ) return openai_chat_completion_chunk - +# For gemini gemini_messages = [{"role": "user", "parts": prompt}] + + +# For QianFan +qf_jsonbody_dict = { + "id": "as-4v1h587fyv", + "object": "chat.completion", + "created": 1695021339, + "result": "", + "is_truncated": False, + "need_clear_history": False, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 15, + "total_tokens": 22 + } +} + + +def get_qianfan_response(name: str) -> QfResponse: + qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name) + return QfResponse( + code=200, + body=qf_jsonbody_dict + ) + + +# For llm general chat functions call +async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str): + resp = await llm.aask(prompt, stream=False) + assert resp == resp_cont + + resp = await llm.aask(prompt) + assert resp == resp_cont + + resp = await llm.acompletion_text(messages, stream=False) + assert resp == resp_cont + + resp = await llm.acompletion_text(messages, stream=True) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 0babd6d5f..cf44343bc 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -17,7 +17,7 @@ from tests.metagpt.provider.req_resp_const import ( prompt, ) -llm_name = "GPT" +name = "GPT" class MockBaseLLM(BaseLLM): @@ -25,10 +25,10 @@ class MockBaseLLM(BaseLLM): 
pass def completion(self, messages: list[dict], timeout=3): - return get_part_chat_completion(llm_name) + return get_part_chat_completion(name) async def acompletion(self, messages: list[dict], timeout=3): - return get_part_chat_completion(llm_name) + return get_part_chat_completion(name) async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: return default_resp_cont diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py index 834f6305f..e28f7500b 100644 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ b/tests/metagpt/provider/test_fireworks_llm.py @@ -21,10 +21,10 @@ from tests.metagpt.provider.req_resp_const import ( resp_cont_tmpl, ) -llm_name = "fireworks" -resp_cont = resp_cont_tmpl.format(name=llm_name) -default_resp = get_openai_chat_completion(llm_name) -default_resp_chunk = get_openai_chat_completion_chunk(llm_name, usage_as_dict=True) +name = "fireworks" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_openai_chat_completion(name) +default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True) def test_fireworks_costmanager(): @@ -57,27 +57,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def test_fireworks_acompletion(mocker): mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - fireworks_gpt = FireworksLLM(mock_llm_config) - fireworks_gpt.model = "llama-v2-13b-chat" + fireworks_llm = FireworksLLM(mock_llm_config) + fireworks_llm.model = "llama-v2-13b-chat" - fireworks_gpt._update_costs( + fireworks_llm._update_costs( usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) ) - assert fireworks_gpt.get_costs() == Costs( + assert fireworks_llm.get_costs() == Costs( total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 ) - resp = await fireworks_gpt.acompletion(messages) + resp = await fireworks_llm.acompletion(messages) assert resp.choices[0].message.content in resp_cont - resp = await fireworks_gpt.aask(prompt, stream=False) + resp = await fireworks_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await fireworks_gpt.acompletion_text(messages, stream=False) + resp = await fireworks_llm.acompletion_text(messages, stream=False) assert resp == resp_cont - resp = await fireworks_gpt.acompletion_text(messages, stream=True) + resp = await fireworks_llm.acompletion_text(messages, stream=True) assert resp == resp_cont - resp = await fireworks_gpt.aask(prompt) + resp = await fireworks_llm.aask(prompt) assert resp == resp_cont diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index ad0c7bbfe..dae9d123b 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -63,28 +63,28 @@ async def test_gemini_acompletion(mocker): mock_gemini_generate_content_async, ) - gemini_gpt = GeminiLLM(mock_llm_config) + gemini_llm = GeminiLLM(mock_llm_config) - assert gemini_gpt._user_msg(prompt) == {"role": "user", "parts": [prompt]} - assert gemini_gpt._assistant_msg(prompt) == {"role": "model", "parts": [prompt]} + assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]} + assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]} - usage = gemini_gpt.get_usage(gemini_messages, resp_cont) + usage = gemini_llm.get_usage(gemini_messages, 
resp_cont) assert usage == {"prompt_tokens": 20, "completion_tokens": 20} - resp = gemini_gpt.completion(gemini_messages) + resp = gemini_llm.completion(gemini_messages) assert resp == default_resp - resp = await gemini_gpt.acompletion(gemini_messages) + resp = await gemini_llm.acompletion(gemini_messages) assert resp.text == default_resp.text - resp = await gemini_gpt.aask(prompt, stream=False) + resp = await gemini_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await gemini_gpt.acompletion_text(gemini_messages, stream=False) + resp = await gemini_llm.acompletion_text(gemini_messages, stream=False) assert resp == resp_cont - resp = await gemini_gpt.acompletion_text(gemini_messages, stream=True) + resp = await gemini_llm.acompletion_text(gemini_messages, stream=True) assert resp == resp_cont - resp = await gemini_gpt.aask(prompt) + resp = await gemini_llm.aask(prompt) assert resp == resp_cont diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 8e2625e35..01d53251c 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -39,19 +39,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) - ollama_gpt = OllamaLLM(mock_llm_config) + ollama_llm = OllamaLLM(mock_llm_config) - resp = await ollama_gpt.acompletion(messages) + resp = await ollama_llm.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] - resp = await ollama_gpt.aask(prompt, stream=False) + resp = await ollama_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await ollama_gpt.acompletion_text(messages, stream=False) + resp = await ollama_llm.acompletion_text(messages, stream=False) assert resp == resp_cont - resp = await ollama_gpt.acompletion_text(messages, stream=True) + resp = await ollama_llm.acompletion_text(messages, stream=True) assert resp == resp_cont - resp = await ollama_gpt.aask(prompt) + resp = await ollama_llm.aask(prompt) assert resp == resp_cont diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index 5b8a506e9..b2e759d06 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -17,11 +17,11 @@ from tests.metagpt.provider.req_resp_const import ( resp_cont_tmpl, ) -llm_name = "llama2-7b" -resp_cont = resp_cont_tmpl.format(name=llm_name) -default_resp = get_openai_chat_completion(llm_name) +name = "llama2-7b" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_openai_chat_completion(name) -default_resp_chunk = get_openai_chat_completion_chunk(llm_name) +default_resp_chunk = get_openai_chat_completion_chunk(name) async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: @@ -40,26 +40,26 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def test_openllm_acompletion(mocker): mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - openllm_gpt = OpenLLM(mock_llm_config) - openllm_gpt.model = "llama-v2-13b-chat" + openllm_llm = OpenLLM(mock_llm_config) + openllm_llm.model = "llama-v2-13b-chat" - openllm_gpt.cost_manager = CostManager() - openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, 
completion_tokens=100, total_tokens=200)) - assert openllm_gpt.get_costs() == Costs( + openllm_llm.cost_manager = CostManager() + openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) + assert openllm_llm.get_costs() == Costs( total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 ) - resp = await openllm_gpt.acompletion(messages) + resp = await openllm_llm.acompletion(messages) assert resp.choices[0].message.content in resp_cont - resp = await openllm_gpt.aask(prompt, stream=False) + resp = await openllm_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await openllm_gpt.acompletion_text(messages, stream=False) + resp = await openllm_llm.acompletion_text(messages, stream=False) assert resp == resp_cont - resp = await openllm_gpt.acompletion_text(messages, stream=True) + resp = await openllm_llm.acompletion_text(messages, stream=True) assert resp == resp_cont - resp = await openllm_gpt.aask(prompt) + resp = await openllm_llm.aask(prompt) assert resp == resp_cont diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py index 76271b1e8..30ac06911 100644 --- a/tests/metagpt/provider/test_qianfan_api.py +++ b/tests/metagpt/provider/test_qianfan_api.py @@ -2,14 +2,45 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of qianfan api +from typing import Dict, Union, AsyncIterator import pytest +from qianfan.resources.typing import JsonBody, QfResponse + from metagpt.provider.qianfan_api import QianFanLLM -from tests.metagpt.provider.req_resp_const import prompt, messages, resp_cont_tmpl +from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan +from tests.metagpt.provider.req_resp_const import resp_cont_tmpl, prompt, messages, llm_general_chat_funcs_test, get_qianfan_response + +name = "ERNIE-Bot-turbo" +resp_cont = resp_cont_tmpl.format(name=name) -resp_cont = resp_cont_tmpl.format(name="ERNIE-Bot-turbo") +def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse: + return get_qianfan_response(name=name) -def test_qianfan_acompletion(mocker): - assert True, True +async def mock_qianfan_ado(self, messages: list[dict], model: str, stream: bool = True, system: str = None) -> Union[QfResponse, AsyncIterator[QfResponse]]: + resps = [get_qianfan_response(name=name)] + if stream: + async def aresp_iterator(resps: list[JsonBody]): + for resp in resps: + yield resp + return aresp_iterator(resps) + else: + return resps[0] + + +@pytest.mark.asyncio +async def test_qianfan_acompletion(mocker): + mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do) + mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado) + + qianfan_llm = QianFanLLM(mock_llm_config_qianfan) + + resp = qianfan_llm.completion(messages) + assert resp.get("result") == resp_cont + + resp = await qianfan_llm.acompletion(messages) + assert resp.get("result") == resp_cont + + await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 32a839393..8aa8bc7a8 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -50,19 +50,19 @@ async def test_spark_aask(mocker): async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - 
spark_gpt = SparkLLM(mock_llm_config) + spark_llm = SparkLLM(mock_llm_config) - resp = await spark_gpt.acompletion([]) + resp = await spark_llm.acompletion([]) assert resp == resp_cont - resp = await spark_gpt.aask(prompt, stream=False) + resp = await spark_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await spark_gpt.acompletion_text([], stream=False) + resp = await spark_llm.acompletion_text([], stream=False) assert resp == resp_cont - resp = await spark_gpt.acompletion_text([], stream=True) + resp = await spark_llm.acompletion_text([], stream=True) assert resp == resp_cont - resp = await spark_gpt.aask(prompt) + resp = await spark_llm.aask(prompt) assert resp == resp_cont diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 064562bff..3dada367c 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -11,11 +11,12 @@ from tests.metagpt.provider.req_resp_const import ( messages, prompt, resp_cont_tmpl, + llm_general_chat_funcs_test ) -llm_name = "ChatGLM-4" -resp_cont = resp_cont_tmpl.format(name=llm_name) -default_resp = get_part_chat_completion(llm_name) +name = "ChatGLM-4" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_part_chat_completion(name) async def mock_zhipuai_acreate_stream(**kwargs): @@ -47,22 +48,12 @@ async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate) mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream) - zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu) + zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu) - resp = await zhipu_gpt.acompletion(messages) + resp = await zhipu_llm.acompletion(messages) assert resp["choices"][0]["message"]["content"] == resp_cont - resp = await zhipu_gpt.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await zhipu_gpt.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await zhipu_gpt.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await zhipu_gpt.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont) def test_zhipuai_proxy(): From 997e25e97d291d83e0fc587abde46bf383308a59 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 7 Feb 2024 18:42:22 +0800 Subject: [PATCH 15/15] simplify provider ut code --- tests/metagpt/provider/mock_llm_config.py | 7 +------ tests/metagpt/provider/req_resp_const.py | 16 +++++----------- tests/metagpt/provider/test_fireworks_llm.py | 13 ++----------- .../metagpt/provider/test_google_gemini_api.py | 13 ++----------- tests/metagpt/provider/test_ollama_api.py | 16 +++++++--------- tests/metagpt/provider/test_open_llm_api.py | 13 ++----------- tests/metagpt/provider/test_qianfan_api.py | 18 ++++++++++++++---- tests/metagpt/provider/test_spark_api.py | 18 ++++++------------ tests/metagpt/provider/test_zhipuai_api.py | 2 +- 9 files changed, 40 insertions(+), 76 deletions(-) diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index e61e32e8b..e0afaa51e 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -53,9 +53,4 @@ mock_llm_config_spark = LLMConfig( base_url="wss://spark-api.xf-yun.com/v3.1/chat", ) -mock_llm_config_qianfan = LLMConfig( - api_type="qianfan", - access_key="xxx", - secret_key="xxx", - 
model="ERNIE-Bot-turbo" -) +mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo") diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 20d8e0914..73939e1c6 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : default request & response data for provider unittest -from typing import Dict + from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, @@ -12,7 +12,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as AChoice from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage -from qianfan.resources.typing import QfResponse, default_field +from qianfan.resources.typing import QfResponse from metagpt.provider.base_llm import BaseLLM @@ -80,6 +80,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ) return openai_chat_completion_chunk + # For gemini gemini_messages = [{"role": "user", "parts": prompt}] @@ -92,20 +93,13 @@ qf_jsonbody_dict = { "result": "", "is_truncated": False, "need_clear_history": False, - "usage": { - "prompt_tokens": 7, - "completion_tokens": 15, - "total_tokens": 22 - } + "usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22}, } def get_qianfan_response(name: str) -> QfResponse: qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name) - return QfResponse( - code=200, - body=qf_jsonbody_dict - ) + return QfResponse(code=200, body=qf_jsonbody_dict) # For llm general chat functions call diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py index e28f7500b..1c1aa9caa 100644 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ b/tests/metagpt/provider/test_fireworks_llm.py @@ -16,6 +16,7 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( get_openai_chat_completion, get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, messages, prompt, resp_cont_tmpl, @@ -70,14 +71,4 @@ async def test_fireworks_acompletion(mocker): resp = await fireworks_llm.acompletion(messages) assert resp.choices[0].message.content in resp_cont - resp = await fireworks_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await fireworks_llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await fireworks_llm.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await fireworks_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index dae9d123b..50c15ee19 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -13,6 +13,7 @@ from metagpt.provider.google_gemini_api import GeminiLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( gemini_messages, + llm_general_chat_funcs_test, prompt, resp_cont_tmpl, ) @@ -77,14 +78,4 @@ async def test_gemini_acompletion(mocker): resp = await gemini_llm.acompletion(gemini_messages) assert resp.text == default_resp.text - 
resp = await gemini_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await gemini_llm.acompletion_text(gemini_messages, stream=False) - assert resp == resp_cont - - resp = await gemini_llm.acompletion_text(gemini_messages, stream=True) - assert resp == resp_cont - - resp = await gemini_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 01d53251c..af2e929e9 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -9,7 +9,12 @@ import pytest from metagpt.provider.ollama_api import OllamaLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import messages, prompt, resp_cont_tmpl +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) resp_cont = resp_cont_tmpl.format(name="ollama") default_resp = {"message": {"role": "assistant", "content": resp_cont}} @@ -47,11 +52,4 @@ async def test_gemini_acompletion(mocker): resp = await ollama_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await ollama_llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await ollama_llm.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await ollama_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index b2e759d06..aa38b95a6 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -12,6 +12,7 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( get_openai_chat_completion, get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, messages, prompt, resp_cont_tmpl, @@ -52,14 +53,4 @@ async def test_openllm_acompletion(mocker): resp = await openllm_llm.acompletion(messages) assert resp.choices[0].message.content in resp_cont - resp = await openllm_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await openllm_llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await openllm_llm.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await openllm_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py index 30ac06911..28341425c 100644 --- a/tests/metagpt/provider/test_qianfan_api.py +++ b/tests/metagpt/provider/test_qianfan_api.py @@ -2,14 +2,20 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of qianfan api -from typing import Dict, Union, AsyncIterator -import pytest +from typing import AsyncIterator, Union +import pytest from qianfan.resources.typing import JsonBody, QfResponse from metagpt.provider.qianfan_api import QianFanLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan -from tests.metagpt.provider.req_resp_const import resp_cont_tmpl, prompt, messages, llm_general_chat_funcs_test, get_qianfan_response +from tests.metagpt.provider.req_resp_const import ( + get_qianfan_response, + 
llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) name = "ERNIE-Bot-turbo" resp_cont = resp_cont_tmpl.format(name=name) @@ -19,12 +25,16 @@ def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False return get_qianfan_response(name=name) -async def mock_qianfan_ado(self, messages: list[dict], model: str, stream: bool = True, system: str = None) -> Union[QfResponse, AsyncIterator[QfResponse]]: +async def mock_qianfan_ado( + self, messages: list[dict], model: str, stream: bool = True, system: str = None +) -> Union[QfResponse, AsyncIterator[QfResponse]]: resps = [get_qianfan_response(name=name)] if stream: + async def aresp_iterator(resps: list[JsonBody]): for resp in resps: yield resp + return aresp_iterator(resps) else: return resps[0] diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 8aa8bc7a8..9c278267d 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -9,7 +9,11 @@ from tests.metagpt.provider.mock_llm_config import ( mock_llm_config, mock_llm_config_spark, ) -from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + prompt, + resp_cont_tmpl, +) resp_cont = resp_cont_tmpl.format(name="Spark") @@ -55,14 +59,4 @@ async def test_spark_acompletion(mocker): resp = await spark_llm.acompletion([]) assert resp == resp_cont - resp = await spark_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await spark_llm.acompletion_text([], stream=False) - assert resp == resp_cont - - resp = await spark_llm.acompletion_text([], stream=True) - assert resp == resp_cont - - resp = await spark_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 3dada367c..8ec9ab4f9 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -8,10 +8,10 @@ from metagpt.provider.zhipuai_api import ZhiPuAILLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu from tests.metagpt.provider.req_resp_const import ( get_part_chat_completion, + llm_general_chat_funcs_test, messages, prompt, resp_cont_tmpl, - llm_general_chat_funcs_test ) name = "ChatGLM-4"
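Taken together, PATCH 14 and PATCH 15 leave each provider test with a single await of llm_general_chat_funcs_test instead of four copies of the aask/acompletion_text assertions. A sketch of what a future provider test reduces to — FakeLLM and the name "newllm" are hypothetical stand-ins for a provider whose transport has already been mocked, as the mocker.patch calls above do for the real providers:

import pytest

from tests.metagpt.provider.req_resp_const import (
    llm_general_chat_funcs_test,
    messages,
    prompt,
    resp_cont_tmpl,
)

resp_cont = resp_cont_tmpl.format(name="newllm")


class FakeLLM:
    """Hypothetical stand-in exposing the two methods the shared helper exercises."""

    async def aask(self, msg: str, stream: bool = True) -> str:
        return resp_cont

    async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
        return resp_cont


@pytest.mark.asyncio
async def test_newllm_chat_funcs():
    # One call covers aask (stream on/off) and acompletion_text (stream on/off).
    await llm_general_chat_funcs_test(FakeLLM(), prompt, messages, resp_cont)

Any new provider only needs to mock its request path and supply its expected resp_cont; the four chat-path assertions come for free from the helper.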