From a07f95512448579aa8791466827121691f1109b6 Mon Sep 17 00:00:00 2001 From: geekan Date: Mon, 8 Jan 2024 17:33:06 +0800 Subject: [PATCH] refactor --- metagpt/provider/anthropic_api.py | 2 +- metagpt/provider/base_llm.py | 2 +- metagpt/provider/fireworks_api.py | 2 +- metagpt/provider/google_gemini_api.py | 3 +- metagpt/provider/human_provider.py | 2 +- metagpt/provider/ollama_api.py | 7 +- metagpt/provider/openai_api.py | 2 +- metagpt/provider/spark_api.py | 2 +- metagpt/provider/zhipuai_api.py | 2 +- tests/metagpt/provider/mock_llm_config.py | 9 +++ ...fireworks_api.py => test_fireworks_llm.py} | 0 .../provider/test_google_gemini_api.py | 6 +- tests/metagpt/provider/test_metagpt_api.py | 14 ---- ...metagpt_llm_api.py => test_metagpt_llm.py} | 5 +- tests/metagpt/provider/test_ollama_api.py | 3 +- tests/metagpt/provider/test_open_llm_api.py | 3 +- tests/metagpt/provider/test_openai.py | 78 +++---------------- 17 files changed, 41 insertions(+), 101 deletions(-) rename tests/metagpt/provider/{test_fireworks_api.py => test_fireworks_llm.py} (100%) delete mode 100644 tests/metagpt/provider/test_metagpt_api.py rename tests/metagpt/provider/{test_metagpt_llm_api.py => test_metagpt_llm.py} (63%) diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 2a65b81c1..f31c2d04d 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -13,7 +13,7 @@ from metagpt.configs.llm_config import LLMConfig class Claude2: - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): self.config = config def ask(self, prompt: str) -> str: diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index f13899c38..3c6c464dc 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -29,7 +29,7 @@ class BaseLLM(ABC): cost_manager: Optional[CostManager] = None @abstractmethod - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): pass def _user_msg(self, msg: str) -> dict[str, str]: diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 8c0b268e6..5fbcfdbf0 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -72,7 +72,7 @@ class FireworksCostManager(CostManager): @register_provider(LLMType.FIREWORKS) class FireworksLLM(OpenAILLM): - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): super().__init__(config=config) self.auto_max_tokens = False self.cost_manager = FireworksCostManager() diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 958ea52a5..6df814b55 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -47,10 +47,11 @@ class GeminiLLM(BaseLLM): Refs to `https://ai.google.dev/tutorials/python_quickstart` """ - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): self.use_system_prompt = False # google gemini has no system prompt when use api self.__init_gemini(config) + self.config = config self.model = "gemini-pro" # so far only one model self.llm = GeminiGenerativeModel(model_name=self.model) diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index 25b897d74..fe000b3a6 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -15,7 +15,7 @@ class HumanProvider(BaseLLM): This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction """ - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): pass def ask(self, msg: str, timeout=3) -> str: diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 80d0bf20c..c9103b018 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -29,16 +29,17 @@ class OllamaLLM(BaseLLM): Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion` """ - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): self.__init_ollama(config) - self.client = GeneralAPIRequestor(base_url=config.api_base) + self.client = GeneralAPIRequestor(base_url=config.base_url) + self.config = config self.suffix_url = "/chat" self.http_method = "post" self.use_system_prompt = False self._cost_manager = TokenCostManager() def __init_ollama(self, config: LLMConfig): - assert config.api_base + assert config.base_url, "ollama base url is required!" self.model = config.model def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index d5e9c0221..d60bb8773 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -54,7 +54,7 @@ See FAQ 5.8 class OpenAILLM(BaseLLM): """Check https://platform.openai.com/examples for examples""" - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): self.config = config self._init_model() self._init_client() diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index bc842f202..6ea8722c3 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -24,7 +24,7 @@ from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMType.SPARK) class SparkLLM(BaseLLM): - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): self.config = config logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index f556edc08..0d076b801 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -38,7 +38,7 @@ class ZhiPuAILLM(BaseLLM): From now, there is only one model named `chatglm_turbo` """ - def __init__(self, config: LLMConfig = None): + def __init__(self, config: LLMConfig): self.__init_zhipuai(config) self.llm = ZhiPuModelAPI self.model = "chatglm_turbo" # so far only one model, just use it diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 969ec2ab6..6b1b52335 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -11,4 +11,13 @@ from metagpt.configs.llm_config import LLMConfig mock_llm_config = LLMConfig( llm_type="mock", api_key="mock_api_key", + base_url="mock_base_url", +) + + +mock_llm_config_proxy = LLMConfig( + llm_type="mock", + api_key="mock_api_key", + base_url="mock_base_url", + proxy="http://localhost:8080", ) diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_llm.py similarity index 100% rename from tests/metagpt/provider/test_fireworks_api.py rename to tests/metagpt/provider/test_fireworks_llm.py diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index ffd10df7f..404ae1e90 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -9,10 +9,8 @@ import pytest from google.ai import generativelanguage as glm from google.generativeai.types import content_types -from metagpt.config import CONFIG from metagpt.provider.google_gemini_api import GeminiLLM - -CONFIG.gemini_api_key = "xx" +from tests.metagpt.provider.mock_llm_config import mock_llm_config @dataclass @@ -62,7 +60,7 @@ async def test_gemini_acompletion(mocker): mock_gemini_generate_content_async, ) - gemini_gpt = GeminiLLM() + gemini_gpt = GeminiLLM(mock_llm_config) assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} diff --git a/tests/metagpt/provider/test_metagpt_api.py b/tests/metagpt/provider/test_metagpt_api.py deleted file mode 100644 index 8f42a53c8..000000000 --- a/tests/metagpt/provider/test_metagpt_api.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/12/28 -@Author : mashenquan -@File : test_metagpt_api.py -""" -from metagpt.configs.llm_config import LLMType -from metagpt.llm import LLM - - -def test_llm(): - llm = LLM(provider=LLMType.METAGPT) - assert llm diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm.py similarity index 63% rename from tests/metagpt/provider/test_metagpt_llm_api.py rename to tests/metagpt/provider/test_metagpt_llm.py index 8fce6b6b0..0263fe508 100644 --- a/tests/metagpt/provider/test_metagpt_llm_api.py +++ b/tests/metagpt/provider/test_metagpt_llm.py @@ -3,13 +3,14 @@ """ @Time : 2023/8/30 @Author : mashenquan -@File : test_metagpt_llm_api.py +@File : test_metagpt_llm.py """ from metagpt.provider.metagpt_api import MetaGPTLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config def test_metagpt(): - llm = MetaGPTLLM() + llm = MetaGPTLLM(mock_llm_config) assert llm diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 1c604768e..41f02bf2c 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -9,6 +9,7 @@ import pytest from metagpt.config import CONFIG from metagpt.provider.ollama_api import OllamaLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -44,7 +45,7 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) - ollama_gpt = OllamaLLM() + ollama_gpt = OllamaLLM(mock_llm_config) resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index 85069c5e1..f74bc9c49 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -16,6 +16,7 @@ from openai.types.completion_usage import CompletionUsage from metagpt.config import CONFIG from metagpt.provider.open_llm_api import OpenLLM from metagpt.utils.cost_manager import Costs +from tests.metagpt.provider.mock_llm_config import mock_llm_config CONFIG.max_budget = 10 CONFIG.calc_usage = True @@ -71,7 +72,7 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def test_openllm_acompletion(mocker): mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - openllm_gpt = OpenLLM() + openllm_gpt = OpenLLM(mock_llm_config) openllm_gpt.model = "llama-v2-13b-chat" openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index a996cf5b9..ee69da861 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,10 +1,13 @@ -from unittest.mock import Mock - import pytest from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider import OpenAILLM from metagpt.schema import UserMessage +from tests.metagpt.provider.mock_llm_config import ( + mock_llm_config, + mock_llm_config_proxy, +) @pytest.mark.asyncio @@ -40,74 +43,13 @@ async def test_aask_code_message(): class TestOpenAI: - @pytest.fixture - def config(self): - return Mock( - openai_api_key="test_key", - OPENAI_API_KEY="test_key", - openai_base_url="test_url", - OPENAI_BASE_URL="test_url", - openai_proxy=None, - openai_api_type="other", - ) - - @pytest.fixture - def config_azure(self): - return Mock( - openai_api_key="test_key", - OPENAI_API_KEY="test_key", - openai_api_version="test_version", - openai_base_url="test_url", - OPENAI_BASE_URL="test_url", - openai_proxy=None, - openai_api_type="azure", - ) - - @pytest.fixture - def config_proxy(self): - return Mock( - openai_api_key="test_key", - OPENAI_API_KEY="test_key", - openai_base_url="test_url", - OPENAI_BASE_URL="test_url", - openai_proxy="http://proxy.com", - openai_api_type="other", - ) - - @pytest.fixture - def config_azure_proxy(self): - return Mock( - openai_api_key="test_key", - OPENAI_API_KEY="test_key", - openai_api_version="test_version", - openai_base_url="test_url", - OPENAI_BASE_URL="test_url", - openai_proxy="http://proxy.com", - openai_api_type="azure", - ) - - def test_make_client_kwargs_without_proxy(self, config): - instance = OpenAILLM() - instance.config = config + def test_make_client_kwargs_without_proxy(self): + instance = OpenAILLM(mock_llm_config) kwargs = instance._make_client_kwargs() - assert kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert kwargs == {"api_key": "mock_api_key", "base_url": "mock_base_url"} assert "http_client" not in kwargs - def test_make_client_kwargs_without_proxy_azure(self, config_azure): - instance = OpenAILLM() - instance.config = config_azure - kwargs = instance._make_client_kwargs() - assert kwargs == {"api_key": "test_key", "base_url": "test_url"} - assert "http_client" not in kwargs - - def test_make_client_kwargs_with_proxy(self, config_proxy): - instance = OpenAILLM() - instance.config = config_proxy - kwargs = instance._make_client_kwargs() - assert "http_client" in kwargs - - def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): - instance = OpenAILLM() - instance.config = config_azure_proxy + def test_make_client_kwargs_with_proxy(self): + instance = OpenAILLM(mock_llm_config_proxy) kwargs = instance._make_client_kwargs() assert "http_client" in kwargs