From 0f047e5693ebe5f5f92f95c81cfbd4cf4cd9ad67 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 02:39:00 +0800 Subject: [PATCH] update provider unittests to update coverage rate --- metagpt/actions/action_node.py | 4 +- metagpt/provider/general_api_base.py | 2 +- metagpt/provider/google_gemini_api.py | 3 - metagpt/provider/open_llm_api.py | 1 - .../{postprecess => postprocess}/__init__.py | 0 .../base_postprocess_plugin.py} | 4 +- .../llm_output_postprocess.py} | 10 +- metagpt/provider/zhipuai/zhipu_model_api.py | 2 +- metagpt/provider/zhipuai_api.py | 3 - .../metagpt/provider/postprocess/__init__.py | 3 + .../test_base_postprocess_plugin.py | 38 ++++++++ .../test_llm_output_postprocess.py | 14 +++ tests/metagpt/provider/test_anthropic_api.py | 19 ++-- .../metagpt/provider/test_azure_openai_api.py | 15 +++ tests/metagpt/provider/test_fireworks_api.py | 58 +++++++++--- .../metagpt/provider/test_general_api_base.py | 84 +++++++++++++++++ .../provider/test_general_api_requestor.py | 15 ++- .../provider/test_google_gemini_api.py | 49 ++++++++-- tests/metagpt/provider/test_ollama_api.py | 31 +++++-- tests/metagpt/provider/test_open_llm_api.py | 93 +++++++++++++++++++ tests/metagpt/provider/test_openai.py | 5 + tests/metagpt/provider/test_spark_api.py | 19 ++-- tests/metagpt/provider/test_zhipuai_api.py | 52 +++++++++-- tests/metagpt/provider/zhipuai/__init__.py | 3 + .../provider/zhipuai/test_async_sse_client.py | 18 ++++ .../provider/zhipuai/test_zhipu_model_api.py | 40 ++++++++ 26 files changed, 509 insertions(+), 76 deletions(-) rename metagpt/provider/{postprecess => postprocess}/__init__.py (100%) rename metagpt/provider/{postprecess/base_postprecess_plugin.py => postprocess/base_postprocess_plugin.py} (98%) rename metagpt/provider/{postprecess/llm_output_postprecess.py => postprocess/llm_output_postprocess.py} (58%) create mode 100644 tests/metagpt/provider/postprocess/__init__.py create mode 100644 tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py create mode 100644 tests/metagpt/provider/postprocess/test_llm_output_postprocess.py create mode 100644 tests/metagpt/provider/test_azure_openai_api.py create mode 100644 tests/metagpt/provider/test_general_api_base.py create mode 100644 tests/metagpt/provider/test_open_llm_api.py create mode 100644 tests/metagpt/provider/zhipuai/__init__.py create mode 100644 tests/metagpt/provider/zhipuai/test_async_sse_client.py create mode 100644 tests/metagpt/provider/zhipuai/test_zhipu_model_api.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 3389b8964..35f2b76f8 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.config import CONFIG from metagpt.llm import BaseLLM from metagpt.logs import logger -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" @@ -275,7 +275,7 @@ class ActionNode: output_class = self.create_model_class(output_class_name, output_data_mapping) if schema == "json": - parsed_data = llm_output_postprecess( + parsed_data = llm_output_postprocess( output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" ) else: # using markdown parser diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index 814be2f67..bbe03774c 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -100,7 +100,7 @@ def log_info(message, **params): def log_warn(message, **params): msg = logfmt(dict(message=message, **params)) print(msg, file=sys.stderr) - logger.warn(msg) + logger.warning(msg) def logfmt(props): diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index b9ee73a92..c99a14b38 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -79,9 +79,6 @@ class GeminiLLM(BaseLLM): except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") - def close(self): - pass - def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 6ccdb4da0..7f5870702 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -31,7 +31,6 @@ class OpenLLMCostManager(CostManager): f"Max budget: ${CONFIG.max_budget:.3f} | reference " f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) - CONFIG.total_cost = self.total_cost @register_provider(LLMProviderEnum.OPEN_LLM) diff --git a/metagpt/provider/postprecess/__init__.py b/metagpt/provider/postprocess/__init__.py similarity index 100% rename from metagpt/provider/postprecess/__init__.py rename to metagpt/provider/postprocess/__init__.py diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprocess/base_postprocess_plugin.py similarity index 98% rename from metagpt/provider/postprecess/base_postprecess_plugin.py rename to metagpt/provider/postprocess/base_postprocess_plugin.py index 46646be91..48130ede8 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprocess/base_postprocess_plugin.py @@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import ( ) -class BasePostPrecessPlugin(object): - model = None # the plugin of the `model`, use to judge in `llm_postprecess` +class BasePostProcessPlugin(object): + model = None # the plugin of the `model`, use to judge in `llm_postprocess` def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]: """ diff --git a/metagpt/provider/postprecess/llm_output_postprecess.py b/metagpt/provider/postprocess/llm_output_postprocess.py similarity index 58% rename from metagpt/provider/postprecess/llm_output_postprecess.py rename to metagpt/provider/postprocess/llm_output_postprocess.py index 85405543d..f898ba3d7 100644 --- a/metagpt/provider/postprecess/llm_output_postprecess.py +++ b/metagpt/provider/postprocess/llm_output_postprocess.py @@ -4,17 +4,17 @@ from typing import Union -from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin +from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin -def llm_output_postprecess( +def llm_output_postprocess( output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None ) -> Union[dict, str]: """ - default use BasePostPrecessPlugin if there is not matched plugin. + default use BasePostProcessPlugin if there is not matched plugin. """ # TODO choose different model's plugin according to the model_name - postprecess_plugin = BasePostPrecessPlugin() + postprocess_plugin = BasePostProcessPlugin() - result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key) + result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key) return result diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 19eb52530..72be0f333 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI): zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} """ arr = zhipu_api_url.split("/api/") - # ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke") + # ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke") return f"{arr[0]}/api", f"/{arr[1]}" @classmethod diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index cdc9c63e6..addbe58af 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -68,9 +68,6 @@ class ZhiPuAILLM(BaseLLM): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") - def close(self): - pass - def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] diff --git a/tests/metagpt/provider/postprocess/__init__.py b/tests/metagpt/provider/postprocess/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/provider/postprocess/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py new file mode 100644 index 000000000..e63e4ecfe --- /dev/null +++ b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin + +raw_output = """ +[CONTENT] +{ +"Original Requirements": "xxx" +} +[/CONTENT] +""" +raw_schema = { + "title":"prd", + "type":"object", + "properties":{ + "Original Requirements":{ + "title":"Original Requirements", + "type":"string" + }, + }, + "required":[ + "Original Requirements", + ] + } + + +def test_llm_post_process_plugin(): + post_process_plugin = BasePostProcessPlugin() + + output = post_process_plugin.run( + output=raw_output, + schema=raw_schema + ) + assert "Original Requirements" in output diff --git a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py new file mode 100644 index 000000000..3cb627216 --- /dev/null +++ b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess + +from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import raw_output, raw_schema + + +def test_llm_output_postprocess(): + output = llm_output_postprocess(output=raw_output, schema=raw_schema) + assert "Original Requirements" in output diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py index 4d3de5320..4410717a9 100644 --- a/tests/metagpt/provider/test_anthropic_api.py +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -2,28 +2,33 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of Claude2 -import pytest +import pytest +from anthropic.resources.completions import Completion + +from metagpt.config import CONFIG from metagpt.provider.anthropic_api import Claude2 +CONFIG.anthropic_api_key = "xxx" + prompt = "who are you" resp = "I'am Claude2" -def mock_llm_ask(self, msg: str) -> str: - return resp +def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") -async def mock_llm_aask(self, msg: str) -> str: - return resp +async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") def test_claude2_ask(mocker): - mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask) + mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create) assert resp == Claude2().ask(prompt) @pytest.mark.asyncio async def test_claude2_aask(mocker): - mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask) + mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create) assert resp == await Claude2().aask(prompt) diff --git a/tests/metagpt/provider/test_azure_openai_api.py b/tests/metagpt/provider/test_azure_openai_api.py new file mode 100644 index 000000000..208e3104a --- /dev/null +++ b/tests/metagpt/provider/test_azure_openai_api.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.azure_openai_api import AzureOpenAILLM +from metagpt.config import CONFIG + +CONFIG.OPENAI_API_VERSION = "xx" +CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value + + +def test_azure_openai_api(): + _ = AzureOpenAILLM() diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 496465e5f..d48686eaa 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -8,6 +8,9 @@ from openai.types.chat.chat_completion import ( ChatCompletionMessage, Choice, ) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.config import CONFIG @@ -16,8 +19,11 @@ from metagpt.provider.fireworks_api import ( FireworksCostManager, FireworksLLM, ) +from metagpt.utils.cost_manager import Costs CONFIG.fireworks_api_key = "xxx" +CONFIG.max_budget = 10 +CONFIG.calc_usage = True resp_content = "I'm fireworks" default_resp = ChatCompletion( @@ -36,6 +42,22 @@ default_resp = ChatCompletion( usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta(content=resp_content, role="assistant"), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + usage=dict(default_resp.usage), +) + prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -50,29 +72,37 @@ def test_fireworks_costmanager(): assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat") assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat") - -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion: - return default_resp + cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat") + assert cost_manager.total_cost == 0.5 -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion: - return default_resp +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return default_resp.choices[0].message.content + return Iterator() + else: + return default_resp @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion) - mocker.patch( - "metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream - ) - fireworks_gpt = FireworksLLM() + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - resp = await fireworks_gpt.acompletion(messages, stream=False) + fireworks_gpt = FireworksLLM() + fireworks_gpt.model = "llama-v2-13b-chat" + + fireworks_gpt._update_costs( + usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) + ) + assert fireworks_gpt.get_costs() == Costs( + total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 + ) + + resp = await fireworks_gpt.acompletion(messages) assert resp.choices[0].message.content in resp_content resp = await fireworks_gpt.aask(prompt_msg, stream=False) diff --git a/tests/metagpt/provider/test_general_api_base.py b/tests/metagpt/provider/test_general_api_base.py new file mode 100644 index 000000000..52ba32f01 --- /dev/null +++ b/tests/metagpt/provider/test_general_api_base.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest +import os +import requests +import aiohttp +from typing import Iterator, Tuple, Union, Generator, AsyncGenerator + +from openai import OpenAIError +from metagpt.provider.general_api_base import ApiType, log_debug, log_info, log_warn, OpenAIResponse, \ + _requests_proxies_arg, _aiohttp_proxies_arg, _make_session, parse_stream_helper, parse_stream, APIRequestor + + +def test_basic(): + _ = ApiType.from_str("azure") + _ = ApiType.from_str("azuread") + _ = ApiType.from_str("openai") + with pytest.raises(OpenAIError): + _ = ApiType.from_str("xx") + + os.environ.setdefault("LLM_LOG", "debug") + log_debug("debug") + log_warn("warn") + log_info("info") + + +def test_openai_response(): + resp = OpenAIResponse(data=[], headers={"retry-after": 3}) + assert resp.request_id is None + assert resp.retry_after == 3 + assert resp.operation_location is None + assert resp.organization is None + assert resp.response_ms is None + + +def test_proxy(): + assert _requests_proxies_arg(proxy=None) is None + + proxy = "127.0.0.1:80" + assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy} + proxy_dict = {"http": proxy} + assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + proxy_dict = {"https": proxy} + assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + + assert _make_session() is not None + + +def test_parse_stream(): + assert parse_stream_helper(None) is None + assert parse_stream_helper(b"data: [DONE]") is None + assert parse_stream_helper(b"data: test") == "test" + assert parse_stream_helper(b"test") is None + for line in parse_stream([b"data: test"]): + assert line == "test" + + +api_requestor = APIRequestor(base_url="http://www.baidu.com") + + +def mock_interpret_response(self, result: requests.Response, stream: bool + ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: + return b"baidu", False + + +async def mock_interpret_async_response(self, result: aiohttp.ClientResponse, stream: bool + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + return b"baidu", True + + +def test_api_requestor(mocker): + mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response) + resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") + + resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu") + + +@pytest.mark.asyncio +async def test_async_api_requestor(mocker): + mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response) + resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") + resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu") diff --git a/tests/metagpt/provider/test_general_api_requestor.py b/tests/metagpt/provider/test_general_api_requestor.py index 28130fa65..dcbcc0567 100644 --- a/tests/metagpt/provider/test_general_api_requestor.py +++ b/tests/metagpt/provider/test_general_api_requestor.py @@ -4,11 +4,24 @@ import pytest -from metagpt.provider.general_api_requestor import GeneralAPIRequestor +from metagpt.provider.general_api_requestor import ( + GeneralAPIRequestor, + parse_stream, + parse_stream_helper, +) api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com") +def test_parse_stream(): + assert parse_stream_helper(None) is None + assert parse_stream_helper(b"data: [DONE]") is None + assert parse_stream_helper(b"data: test") == b"test" + assert parse_stream_helper(b"test") is None + for line in parse_stream([b"data: test"]): + assert line == b"test" + + def test_api_requestor(): resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") assert b"baidu" in resp diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 7e372634c..ffd10df7f 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -6,9 +6,14 @@ from abc import ABC from dataclasses import dataclass import pytest +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from metagpt.config import CONFIG from metagpt.provider.google_gemini_api import GeminiLLM +CONFIG.gemini_api_key = "xx" + @dataclass class MockGeminiResponse(ABC): @@ -21,29 +26,53 @@ resp_content = "I'm gemini from google" default_resp = MockGeminiResponse(text=resp_content) -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse: +def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + return glm.CountTokensResponse(total_tokens=20) + + +async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + return glm.CountTokensResponse(total_tokens=20) + + +def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse: return default_resp -async def mock_llm_acompletion( - self, messgaes: list[dict], stream: bool = False, timeout: int = 60 -) -> MockGeminiResponse: - return default_resp +async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content + return Iterator() + else: + return default_resp @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens) mocker.patch( - "metagpt.provider.google_gemini_api.GeminiLLM._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async ) + mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content) + mocker.patch( + "google.generativeai.generative_models.GenerativeModel.generate_content_async", + mock_gemini_generate_content_async, + ) + gemini_gpt = GeminiLLM() + assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} + assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} + + usage = gemini_gpt.get_usage(messages, resp_content) + assert usage == {"prompt_tokens": 20, "completion_tokens": 20} + + resp = gemini_gpt.completion(messages) + assert resp == default_resp + resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index ba019f295..1c604768e 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of ollama api +import json +from typing import Any, Tuple + import pytest from metagpt.config import CONFIG @@ -14,25 +17,33 @@ resp_content = "I'm ollama" default_resp = {"message": {"role": "assistant", "content": resp_content}} CONFIG.ollama_api_base = "http://xxx" +CONFIG.max_budget = 10 -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: - return default_resp +async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: + if stream: + class Iterator(object): + events = [ + b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}', + b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}', + ] -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: - return default_resp + async def __aiter__(self): + for event in self.events: + yield event - -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content + return Iterator(), None, None + else: + raw_default_resp = default_resp.copy() + raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20}) + return json.dumps(raw_default_resp).encode(), None, None @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) + ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py new file mode 100644 index 000000000..bf094d54a --- /dev/null +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, Choice as AChoice +from openai.types.completion_usage import CompletionUsage + +from metagpt.config import CONFIG +from metagpt.provider.open_llm_api import OpenLLM +from metagpt.utils.cost_manager import Costs + +CONFIG.max_budget = 10 +CONFIG.calc_usage = True + +resp_content = "I'm llama2" +default_resp = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="llama-v2-13b-chat", + object="chat.completion", + created=1703302755, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, + ) + ] +) + +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta( + content=resp_content, + role="assistant" + ), + finish_reason="stop", + index=0, + logprobs=None, + ) + ] +) + +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + + +async def mock_openai_acompletions_create(self, stream: bool=False, **kwargs) -> ChatCompletionChunk: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk + return Iterator() + else: + return default_resp + + +@pytest.mark.asyncio +async def test_openllm_acompletion(mocker): + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) + + openllm_gpt = OpenLLM() + openllm_gpt.model = "llama-v2-13b-chat" + + openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) + assert openllm_gpt.get_costs() == Costs(total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0) + + resp = await openllm_gpt.acompletion(messages) + assert resp.choices[0].message.content in resp_content + + resp = await openllm_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await openllm_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index cb86dfcf9..ddc290731 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -2,9 +2,14 @@ from unittest.mock import Mock import pytest +from metagpt.config import CONFIG from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import UserMessage +CONFIG.openai_proxy = None + +print("openai_api_key ", CONFIG.openai_api_key) + @pytest.mark.asyncio async def test_aask_code(): diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index e62c287c0..6d5a0e1f6 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,24 +4,31 @@ import pytest -from metagpt.provider.spark_api import SparkLLM +from metagpt.config import CONFIG +from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM + +CONFIG.spark_appid = "xxx" +CONFIG.spark_api_secret = "xxx" +CONFIG.spark_api_key = "xxx" +CONFIG.domain = "xxxxxx" +CONFIG.spark_url = "xxxx" prompt_msg = "who are you" resp_content = "I'm Spark" -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str: - return resp_content +def test_get_msg_from_web(): + get_msg_from_web = GetMessageFromWeb(text=prompt_msg) + assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx" -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str: +def mock_spark_get_msg_from_web_run(self) -> str: return resp_content @pytest.mark.asyncio async def test_spark_acompletion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion_text", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) spark_gpt = SparkLLM() resp = await spark_gpt.acompletion([]) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index c1af2f0be..826e706e8 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -3,36 +3,68 @@ # @Desc : the unittest of ZhiPuAILLM import pytest +from zhipuai.utils.sse_client import Event from metagpt.config import CONFIG from metagpt.provider.zhipuai_api import ZhiPuAILLM -CONFIG.zhipuai_api_key = "xxx" +CONFIG.zhipuai_api_key = "xxx.xxx" prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] resp_content = "I'm chatglm-turbo" -default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}} +default_resp = { + "code": 200, + "data": { + "choices": [{"role": "assistant", "content": resp_content}], + "usage": {"prompt_tokens": 20, "completion_tokens": 20}, + }, +} -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: +def mock_zhipuai_invoke(**kwargs) -> dict: return default_resp -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: +async def mock_zhipuai_ainvoke(**kwargs) -> dict: return default_resp -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content +async def mock_zhipuai_asse_invoke(**kwargs): + class MockResponse(object): + async def _aread(self): + class Iterator(object): + events = [ + Event(id="xxx", event="add", data=resp_content, retry=0), + Event( + id="xxx", + event="finish", + data="", + meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}', + ), + ] + + async def __aiter__(self): + for event in self.events: + yield event + + async for chunk in Iterator(): + yield chunk + + async def async_events(self): + async for chunk in self._aread(): + yield chunk + + return MockResponse() @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke) + zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) @@ -51,7 +83,7 @@ async def test_zhipuai_acompletion(mocker): assert resp == resp_content -def test_zhipuai_proxy(mocker): +def test_zhipuai_proxy(): import openai from metagpt.config import CONFIG diff --git a/tests/metagpt/provider/zhipuai/__init__.py b/tests/metagpt/provider/zhipuai/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/provider/zhipuai/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py new file mode 100644 index 000000000..af75e40df --- /dev/null +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient + + +@pytest.mark.asyncio +async def test_async_sse_client(): + class Iterator(object): + async def __aiter__(self): + yield b"data: test_value" + + async_sse_client = AsyncSSEClient(event_source=Iterator()) + async for event in async_sse_client.async_events(): + assert event.data, "test_value" diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py new file mode 100644 index 000000000..b3838e813 --- /dev/null +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any, Tuple +import pytest + +import zhipuai +from zhipuai.model_api.api import InvokeType +from zhipuai.utils.http_client import headers as zhipuai_default_headers + +from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI + + +api_key = "xxx.xxx" +zhipuai.api_key = api_key + +default_resp = {"result": "test response"} + + +async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: + return default_resp, None, None + + +@pytest.mark.asyncio +async def test_zhipu_model_api(mocker): + header = ZhiPuModelAPI.get_header() + zhipuai_default_headers.update({"Authorization": api_key}) + assert header == zhipuai_default_headers + + sse_header = ZhiPuModelAPI.get_sse_header() + assert len(sse_header["Authorization"]) == 191 + + url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"}) + assert url_prefix == "https://open.bigmodel.cn/api" + assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" + + mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) + result = await ZhiPuModelAPI.arequest(InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}) + assert result == default_resp