This commit is contained in:
geekan 2024-01-08 17:33:06 +08:00
parent 244fa81ffe
commit a07f955124
17 changed files with 41 additions and 101 deletions

View file

@ -13,7 +13,7 @@ from metagpt.configs.llm_config import LLMConfig
class Claude2:
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
self.config = config
def ask(self, prompt: str) -> str:

View file

@ -29,7 +29,7 @@ class BaseLLM(ABC):
cost_manager: Optional[CostManager] = None
@abstractmethod
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
pass
def _user_msg(self, msg: str) -> dict[str, str]:

View file

@ -72,7 +72,7 @@ class FireworksCostManager(CostManager):
@register_provider(LLMType.FIREWORKS)
class FireworksLLM(OpenAILLM):
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
super().__init__(config=config)
self.auto_max_tokens = False
self.cost_manager = FireworksCostManager()

View file

@ -47,10 +47,11 @@ class GeminiLLM(BaseLLM):
Refs to `https://ai.google.dev/tutorials/python_quickstart`
"""
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
self.use_system_prompt = False # google gemini has no system prompt when use api
self.__init_gemini(config)
self.config = config
self.model = "gemini-pro" # so far only one model
self.llm = GeminiGenerativeModel(model_name=self.model)

View file

@ -15,7 +15,7 @@ class HumanProvider(BaseLLM):
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
"""
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
pass
def ask(self, msg: str, timeout=3) -> str:

View file

@ -29,16 +29,17 @@ class OllamaLLM(BaseLLM):
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
"""
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
self.__init_ollama(config)
self.client = GeneralAPIRequestor(base_url=config.api_base)
self.client = GeneralAPIRequestor(base_url=config.base_url)
self.config = config
self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self._cost_manager = TokenCostManager()
def __init_ollama(self, config: LLMConfig):
assert config.api_base
assert config.base_url, "ollama base url is required!"
self.model = config.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:

View file

@ -54,7 +54,7 @@ See FAQ 5.8
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()

View file

@ -24,7 +24,7 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMType.SPARK)
class SparkLLM(BaseLLM):
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
self.config = config
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")

View file

@ -38,7 +38,7 @@ class ZhiPuAILLM(BaseLLM):
From now, there is only one model named `chatglm_turbo`
"""
def __init__(self, config: LLMConfig = None):
def __init__(self, config: LLMConfig):
self.__init_zhipuai(config)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it

View file

@ -11,4 +11,13 @@ from metagpt.configs.llm_config import LLMConfig
mock_llm_config = LLMConfig(
llm_type="mock",
api_key="mock_api_key",
base_url="mock_base_url",
)
mock_llm_config_proxy = LLMConfig(
llm_type="mock",
api_key="mock_api_key",
base_url="mock_base_url",
proxy="http://localhost:8080",
)

View file

@ -9,10 +9,8 @@ import pytest
from google.ai import generativelanguage as glm
from google.generativeai.types import content_types
from metagpt.config import CONFIG
from metagpt.provider.google_gemini_api import GeminiLLM
CONFIG.gemini_api_key = "xx"
from tests.metagpt.provider.mock_llm_config import mock_llm_config
@dataclass
@ -62,7 +60,7 @@ async def test_gemini_acompletion(mocker):
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM()
gemini_gpt = GeminiLLM(mock_llm_config)
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}

View file

@ -1,14 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_metagpt_api.py
"""
from metagpt.configs.llm_config import LLMType
from metagpt.llm import LLM
def test_llm():
llm = LLM(provider=LLMType.METAGPT)
assert llm

View file

@ -3,13 +3,14 @@
"""
@Time : 2023/8/30
@Author : mashenquan
@File : test_metagpt_llm_api.py
@File : test_metagpt_llm.py
"""
from metagpt.provider.metagpt_api import MetaGPTLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
def test_metagpt():
llm = MetaGPTLLM()
llm = MetaGPTLLM(mock_llm_config)
assert llm

View file

@ -9,6 +9,7 @@ import pytest
from metagpt.config import CONFIG
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
@ -44,7 +45,7 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM()
ollama_gpt = OllamaLLM(mock_llm_config)
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]

View file

@ -16,6 +16,7 @@ from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
CONFIG.max_budget = 10
CONFIG.calc_usage = True
@ -71,7 +72,7 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_gpt = OpenLLM()
openllm_gpt = OpenLLM(mock_llm_config)
openllm_gpt.model = "llama-v2-13b-chat"
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))

View file

@ -1,10 +1,13 @@
from unittest.mock import Mock
import pytest
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider import OpenAILLM
from metagpt.schema import UserMessage
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_proxy,
)
@pytest.mark.asyncio
@ -40,74 +43,13 @@ async def test_aask_code_message():
class TestOpenAI:
@pytest.fixture
def config(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy=None,
openai_api_type="other",
)
@pytest.fixture
def config_azure(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_api_version="test_version",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy=None,
openai_api_type="azure",
)
@pytest.fixture
def config_proxy(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy="http://proxy.com",
openai_api_type="other",
)
@pytest.fixture
def config_azure_proxy(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_api_version="test_version",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy="http://proxy.com",
openai_api_type="azure",
)
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAILLM()
instance.config = config
def test_make_client_kwargs_without_proxy(self):
instance = OpenAILLM(mock_llm_config)
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert kwargs == {"api_key": "mock_api_key", "base_url": "mock_base_url"}
assert "http_client" not in kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAILLM()
instance.config = config_azure
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAILLM()
instance.config = config_proxy
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAILLM()
instance.config = config_azure_proxy
def test_make_client_kwargs_with_proxy(self):
instance = OpenAILLM(mock_llm_config_proxy)
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs