mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-20 15:38:09 +02:00
refine code
This commit is contained in:
parent
ba8bf01870
commit
0435b1321f
53 changed files with 118 additions and 289 deletions
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
|
|
@ -27,7 +27,7 @@ prompt_msg = "who are you"
|
|||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class MockBaseGPTAPI(BaseGPTAPI):
|
||||
class MockBaseGPTAPI(BaseLLM):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage
|
|||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireWorksGPTAPI,
|
||||
FireworksLLM,
|
||||
)
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
|
|
@ -62,7 +62,7 @@ async def test_fireworks_acompletion(mocker):
|
|||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
fireworks_gpt = FireworksLLM()
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
|
|
|||
|
|
@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp.text == resp_content
|
||||
|
||||
resp = gemini_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
|
|||
|
|
@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
|||
return mock_llm_ask(msg)
|
||||
|
||||
|
||||
def test_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
assert resp_content == human_provider.ask(None)
|
||||
|
||||
assert not human_provider.completion(messages=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
|
@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
resp = ollama_gpt.completion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = ollama_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
ollama_gpt = OllamaLLM()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -18,7 +18,7 @@ async def test_aask_code():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -28,7 +28,7 @@ async def test_aask_code_str():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -84,7 +84,7 @@ class TestOpenAI:
|
|||
)
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
|
|
@ -93,7 +93,7 @@ class TestOpenAI:
|
|||
assert "http_client" not in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
|
|
@ -102,14 +102,14 @@ class TestOpenAI:
|
|||
assert "http_client" not in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self, config_proxy):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_proxy
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
assert "http_client" in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure_proxy
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkGPTAPI
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
|
@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False,
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_spark_completion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = spark_gpt.completion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = spark_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([], stream=False)
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
|
|
|
|||
|
|
@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = zhipu_gpt.completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
resp = zhipu_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
|
|||
|
|
@ -14,15 +14,6 @@ from metagpt.logs import logger
|
|||
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
class TestGPT:
|
||||
def test_llm_api_ask(self, llm_api):
|
||||
answer = llm_api.ask("hello chatgpt")
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
def test_gptapi_ask_batch(self, llm_api):
|
||||
answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60)
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask(self, llm_api):
|
||||
answer = await llm_api.aask("hello chatgpt", stream=False)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue