mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
refactor
This commit is contained in:
parent
244fa81ffe
commit
a07f955124
17 changed files with 41 additions and 101 deletions
|
|
@ -13,7 +13,7 @@ from metagpt.configs.llm_config import LLMConfig
|
|||
|
||||
|
||||
class Claude2:
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
|
||||
def ask(self, prompt: str) -> str:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class BaseLLM(ABC):
|
|||
cost_manager: Optional[CostManager] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class FireworksCostManager(CostManager):
|
|||
|
||||
@register_provider(LLMType.FIREWORKS)
|
||||
class FireworksLLM(OpenAILLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config=config)
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager = FireworksCostManager()
|
||||
|
|
|
|||
|
|
@ -47,10 +47,11 @@ class GeminiLLM(BaseLLM):
|
|||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.use_system_prompt = False # google gemini has no system prompt when use api
|
||||
|
||||
self.__init_gemini(config)
|
||||
self.config = config
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class HumanProvider(BaseLLM):
|
|||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
|
|
|
|||
|
|
@ -29,16 +29,17 @@ class OllamaLLM(BaseLLM):
|
|||
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.__init_ollama(config)
|
||||
self.client = GeneralAPIRequestor(base_url=config.api_base)
|
||||
self.client = GeneralAPIRequestor(base_url=config.base_url)
|
||||
self.config = config
|
||||
self.suffix_url = "/chat"
|
||||
self.http_method = "post"
|
||||
self.use_system_prompt = False
|
||||
self._cost_manager = TokenCostManager()
|
||||
|
||||
def __init_ollama(self, config: LLMConfig):
|
||||
assert config.api_base
|
||||
assert config.base_url, "ollama base url is required!"
|
||||
self.model = config.model
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ See FAQ 5.8
|
|||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from metagpt.provider.llm_provider_registry import register_provider
|
|||
|
||||
@register_provider(LLMType.SPARK)
|
||||
class SparkLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class ZhiPuAILLM(BaseLLM):
|
|||
From now, there is only one model named `chatglm_turbo`
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.__init_zhipuai(config)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
|
|
|
|||
|
|
@ -11,4 +11,13 @@ from metagpt.configs.llm_config import LLMConfig
|
|||
mock_llm_config = LLMConfig(
|
||||
llm_type="mock",
|
||||
api_key="mock_api_key",
|
||||
base_url="mock_base_url",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_proxy = LLMConfig(
|
||||
llm_type="mock",
|
||||
api_key="mock_api_key",
|
||||
base_url="mock_base_url",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,8 @@ import pytest
|
|||
from google.ai import generativelanguage as glm
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
|
||||
CONFIG.gemini_api_key = "xx"
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -62,7 +60,7 @@ async def test_gemini_acompletion(mocker):
|
|||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM()
|
||||
gemini_gpt = GeminiLLM(mock_llm_config)
|
||||
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_api.py
|
||||
"""
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_llm():
|
||||
llm = LLM(provider=LLMType.METAGPT)
|
||||
assert llm
|
||||
|
|
@ -3,13 +3,14 @@
|
|||
"""
|
||||
@Time : 2023/8/30
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
@File : test_metagpt_llm.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTLLM()
|
||||
llm = MetaGPTLLM(mock_llm_config)
|
||||
assert llm
|
||||
|
||||
|
||||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
|
@ -44,7 +45,7 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
|
|||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM()
|
||||
ollama_gpt = OllamaLLM(mock_llm_config)
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from openai.types.completion_usage import CompletionUsage
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
|
|
@ -71,7 +72,7 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_gpt = OpenLLM()
|
||||
openllm_gpt = OpenLLM(mock_llm_config)
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_proxy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -40,74 +43,13 @@ async def test_aask_code_message():
|
|||
|
||||
|
||||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
def test_make_client_kwargs_without_proxy(self):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert kwargs == {"api_key": "mock_api_key", "base_url": "mock_base_url"}
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self, config_proxy):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_proxy
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure_proxy
|
||||
def test_make_client_kwargs_with_proxy(self):
|
||||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue