mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-07-02 16:01:04 +02:00
resolve problem and add cost manager
This commit is contained in:
parent
f14a1f63ef
commit
0006b62901
4 changed files with 112 additions and 63 deletions
|
|
@ -5,7 +5,6 @@ import pytest
|
|||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
get_max_tokens,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
|
|
@ -17,14 +16,19 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
usage = {
|
||||
"prompt_tokens": 1000000,
|
||||
"completion_tokens": 1000000,
|
||||
}
|
||||
|
||||
|
||||
def mock_bedrock_provider_response(self, *args, **kwargs) -> dict:
|
||||
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
provider = self.config.model.split(".")[0]
|
||||
self._update_costs(usage, self.config.model)
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
||||
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
# use json object to mock EventStream
|
||||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
|
@ -43,6 +47,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
|||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response_body_stream
|
||||
|
||||
|
||||
|
|
@ -82,41 +87,23 @@ def bedrock_api(request) -> BedrockLLM:
|
|||
|
||||
class TestBedrockAPI:
|
||||
def _patch_invoke_model(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response,
|
||||
mock_invoke_model_stream,
|
||||
)
|
||||
|
||||
def test_const_kwargs(self, bedrock_api: BedrockLLM):
|
||||
provider = bedrock_api.provider
|
||||
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
|
||||
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
def test_completion(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
assert bedrock_api.completion(messages) == "Hello World"
|
||||
|
||||
def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
assert bedrock_api._chat_completion_stream(messages) == "Hello World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
self._patch_invoke_model(mocker)
|
||||
assert await bedrock_api._achat_completion_stream(messages) == "Hello World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
assert await bedrock_api.acompletion(messages) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=True) == "Hello World"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue