refine code

This commit is contained in:
geekan 2023-12-26 17:54:52 +08:00
parent ba8bf01870
commit 0435b1321f
53 changed files with 118 additions and 289 deletions

View file

@ -10,7 +10,7 @@ import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.provider.openai_api import OpenAILLM as LLM
from metagpt.schema import CodingContext, Document
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
default_chat_resp = {
@ -27,7 +27,7 @@ prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
class MockBaseGPTAPI(BaseGPTAPI):
class MockBaseGPTAPI(BaseLLM):
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp

View file

@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
MODEL_GRADE_TOKEN_COSTS,
FireworksCostManager,
FireWorksGPTAPI,
FireworksLLM,
)
resp_content = "I'm fireworks"
@ -62,7 +62,7 @@ async def test_fireworks_acompletion(mocker):
mocker.patch(
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
)
fireworks_gpt = FireWorksGPTAPI()
fireworks_gpt = FireworksLLM()
resp = await fireworks_gpt.acompletion(messages, stream=False)
assert resp.choices[0].message.content in resp_content

View file

@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
gemini_gpt = GeminiGPTAPI()
resp = gemini_gpt.completion(messages)
assert resp.text == resp_content
resp = gemini_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)

View file

@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
return mock_llm_ask(msg)
def test_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
human_provider = HumanProvider()
assert resp_content == human_provider.ask(None)
assert not human_provider.completion(messages=[])
@pytest.mark.asyncio
async def test_async_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)

View file

@ -5,7 +5,7 @@
import pytest
from metagpt.config import CONFIG
from metagpt.provider.ollama_api import OllamaGPTAPI
from metagpt.provider.ollama_api import OllamaLLM
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
ollama_gpt = OllamaGPTAPI()
resp = ollama_gpt.completion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = ollama_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
ollama_gpt = OllamaGPTAPI()
ollama_gpt = OllamaLLM()
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]

View file

@ -2,13 +2,13 @@ from unittest.mock import Mock
import pytest
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import UserMessage
@pytest.mark.asyncio
async def test_aask_code():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -18,7 +18,7 @@ async def test_aask_code():
@pytest.mark.asyncio
async def test_aask_code_str():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = "Write a python hello world code."
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -28,7 +28,7 @@ async def test_aask_code_str():
@pytest.mark.asyncio
async def test_aask_code_Message():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = UserMessage("Write a python hello world code.")
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -84,7 +84,7 @@ class TestOpenAI:
)
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
@ -93,7 +93,7 @@ class TestOpenAI:
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
@ -102,14 +102,14 @@ class TestOpenAI:
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs

View file

@ -4,7 +4,7 @@
import pytest
from metagpt.provider.spark_api import SparkGPTAPI
from metagpt.provider.spark_api import SparkLLM
prompt_msg = "who are you"
resp_content = "I'm Spark"
@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False,
return resp_content
def test_spark_completion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
spark_gpt = SparkGPTAPI()
resp = spark_gpt.completion([])
assert resp == resp_content
resp = spark_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
spark_gpt = SparkGPTAPI()
spark_gpt = SparkLLM()
resp = await spark_gpt.acompletion([], stream=False)
resp = await spark_gpt.acompletion([])
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg, stream=False)

View file

@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_zhipuai_completion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
zhipu_gpt = ZhiPuAIGPTAPI()
resp = zhipu_gpt.completion(messages)
assert resp["code"] == 200
assert resp["data"]["choices"][0]["content"] == resp_content
resp = zhipu_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)

View file

@ -14,15 +14,6 @@ from metagpt.logs import logger
@pytest.mark.usefixtures("llm_api")
class TestGPT:
def test_llm_api_ask(self, llm_api):
answer = llm_api.ask("hello chatgpt")
logger.info(answer)
assert len(answer) > 0
def test_gptapi_ask_batch(self, llm_api):
answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"], timeout=60)
assert len(answer) > 0
@pytest.mark.asyncio
async def test_llm_api_aask(self, llm_api):
answer = await llm_api.aask("hello chatgpt", stream=False)

View file

@ -9,7 +9,7 @@
import pytest
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.provider.openai_api import OpenAILLM as LLM
@pytest.fixture()