update provider class names to a uniform naming convention and update tests accordingly

This commit is contained in:
better629 2023-12-28 17:18:18 +08:00
parent d40c4f5025
commit 255e2d3fa7
13 changed files with 48 additions and 50 deletions

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import logger
from metagpt.provider import MetaGPTAPI
from metagpt.provider import MetaGPTLLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, SimpleMessage
from metagpt.utils.redis import Redis
@ -122,7 +122,7 @@ class BrainMemory(BaseModel):
return v
async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
if isinstance(llm, MetaGPTAPI):
if isinstance(llm, MetaGPTLLM):
return await self._metagpt_summarize(max_words=max_words)
self.llm = llm
@ -175,7 +175,7 @@ class BrainMemory(BaseModel):
async def get_title(self, llm, max_words=5, **kwargs) -> str:
"""Generate text title"""
if isinstance(llm, MetaGPTAPI):
if isinstance(llm, MetaGPTLLM):
return self.history[0].content if self.history else "New"
summary = await self.summarize(llm=llm, max_words=500)
@ -190,7 +190,7 @@ class BrainMemory(BaseModel):
return response
async def is_related(self, text1, text2, llm):
if isinstance(llm, MetaGPTAPI):
if isinstance(llm, MetaGPTLLM):
return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
return await self._openai_is_related(text1=text1, text2=text2, llm=llm)
@ -212,7 +212,7 @@ class BrainMemory(BaseModel):
return result
async def rewrite(self, sentence: str, context: str, llm):
if isinstance(llm, MetaGPTAPI):
if isinstance(llm, MetaGPTLLM):
return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)

View file

@ -7,21 +7,21 @@
"""
from metagpt.provider.fireworks_api import FireworksLLM
from metagpt.provider.google_gemini_api import GeminiGPTAPI
from metagpt.provider.google_gemini_api import GeminiLLM
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTAPI
from metagpt.provider.metagpt_api import MetaGPTLLM
__all__ = [
"FireworksLLM",
"GeminiGPTAPI",
"OpenLLMGPTAPI",
"GeminiLLM",
"OpenLLM",
"OpenAILLM",
"ZhiPuAIGPTAPI",
"ZhiPuAILLM",
"AzureOpenAILLM",
"MetaGPTAPI",
"MetaGPTLLM",
"OllamaLLM",
]

View file

@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel):
@register_provider(LLMProviderEnum.GEMINI)
class GeminiGPTAPI(BaseLLM):
class GeminiLLM(BaseLLM):
"""
Refs to `https://ai.google.dev/tutorials/python_quickstart`
"""

View file

@ -11,6 +11,6 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.METAGPT)
class MetaGPTAPI(OpenAILLM):
class MetaGPTLLM(OpenAILLM):
def __init__(self):
super().__init__()

View file

@ -35,7 +35,7 @@ class OpenLLMCostManager(CostManager):
@register_provider(LLMProviderEnum.OPEN_LLM)
class OpenLLMGPTAPI(OpenAILLM):
class OpenLLM(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_openllm()

View file

@ -5,6 +5,7 @@
import json
from enum import Enum
import openai
import zhipuai
from requests import ConnectionError
from tenacity import (
@ -31,7 +32,7 @@ class ZhiPuEvent(Enum):
@register_provider(LLMProviderEnum.ZHIPUAI)
class ZhiPuAIGPTAPI(BaseLLM):
class ZhiPuAILLM(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`

View file

@ -15,6 +15,9 @@ from metagpt.provider.fireworks_api import (
FireworksCostManager,
FireworksLLM,
)
from metagpt.config import CONFIG
CONFIG.fireworks_api_key = "xxx"
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
@ -23,7 +26,7 @@ default_resp = ChatCompletion(
object="chat.completion",
created=1703300855,
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content))
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content), logprobs=None)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
@ -57,10 +60,10 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
"metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream
)
fireworks_gpt = FireworksLLM()

View file

@ -7,7 +7,7 @@ from dataclasses import dataclass
import pytest
from metagpt.provider.google_gemini_api import GeminiGPTAPI
from metagpt.provider.google_gemini_api import GeminiLLM
@dataclass
@ -37,12 +37,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
"metagpt.provider.google_gemini_api.GeminiLLM._achat_completion_stream", mock_llm_achat_completion_stream
)
gemini_gpt = GeminiGPTAPI()
gemini_gpt = GeminiLLM()
resp = await gemini_gpt.acompletion(messages)
assert resp.text == default_resp.text

View file

@ -5,11 +5,11 @@
@Author : mashenquan
@File : test_metagpt_llm_api.py
"""
from metagpt.provider.metagpt_api import MetaGPTAPI
from metagpt.provider.metagpt_api import MetaGPTLLM
def test_metagpt():
llm = MetaGPTAPI()
llm = MetaGPTLLM()
assert llm

View file

@ -30,9 +30,9 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
mocker.patch("metagpt.provider.ollama_api.OllamaLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion_stream", mock_llm_achat_completion_stream)
ollama_gpt = OllamaLLM()
resp = await ollama_gpt.acompletion(messages)

View file

@ -86,31 +86,25 @@ class TestOpenAI:
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs

View file

@ -20,8 +20,8 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False,
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion_text", mock_llm_acompletion)
spark_gpt = SparkLLM()
resp = await spark_gpt.acompletion([])

View file

@ -1,11 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of ZhiPuAIGPTAPI
# @Desc : the unittest of ZhiPuAILLM
import pytest
from metagpt.config import CONFIG
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAILLM
CONFIG.zhipuai_api_key = "xxx"
@ -30,12 +30,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
@pytest.mark.asyncio
async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion)
mocker.patch(
"metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
"metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream
)
zhipu_gpt = ZhiPuAIGPTAPI()
zhipu_gpt = ZhiPuAILLM()
resp = await zhipu_gpt.acompletion(messages)
assert resp["data"]["choices"][0]["content"] == resp_content
@ -59,5 +59,5 @@ def test_zhipuai_proxy(mocker):
from metagpt.config import CONFIG
CONFIG.openai_proxy = "http://127.0.0.1:8080"
_ = ZhiPuAIGPTAPI()
_ = ZhiPuAILLM()
assert openai.proxy == CONFIG.openai_proxy