mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
add test
This commit is contained in:
parent
cafe666bfd
commit
4c394a1cac
7 changed files with 198 additions and 29 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Literal
|
||||
import json
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
|
|
@ -8,9 +9,10 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider
|
|||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
try:
|
||||
import boto3
|
||||
from botocore.response import StreamingBody
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"boto3 not found! please install it by `pip install boto3` first ")
|
||||
"boto3 not found! please install it by `pip install boto3` ")
|
||||
|
||||
|
||||
@register_provider([LLMType.AMAZON_BEDROCK])
|
||||
|
|
@ -25,7 +27,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
logger.warning(
|
||||
"Amazon bedrock doesn't support asynchronous calls now")
|
||||
"Amazon bedrock doesn't support asynchronous now")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
"""initialize boto3 client"""
|
||||
|
|
@ -39,6 +41,12 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
client = session.client(service_name)
|
||||
return client
|
||||
|
||||
def _get_client(self):
|
||||
return self.__client
|
||||
|
||||
def _get_provider(self):
|
||||
return self.__provider
|
||||
|
||||
def list_models(self):
|
||||
"""list all available text-generation models
|
||||
|
||||
|
|
@ -55,6 +63,19 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
for summary in response["modelSummaries"]]
|
||||
logger.info("\n"+"\n".join(summaries))
|
||||
|
||||
def invoke_model(self, request_body) -> dict:
|
||||
response = self.__client.invoke_model(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body) -> StreamingBody:
|
||||
response = self.__client.invoke_model_with_response_stream(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
return response
|
||||
|
||||
@property
|
||||
def _generate_kwargs(self) -> dict:
|
||||
model_max_tokens = get_max_tokens(self.config.model)
|
||||
|
|
@ -71,10 +92,8 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
def completion(self, messages: list[dict]) -> str:
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
response = self.__client.invoke_model(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
completions = self.__provider.get_choice_text(response)
|
||||
response_body = self.invoke_model(request_body)
|
||||
completions = self.__provider.get_choice_text(response_body)
|
||||
return completions
|
||||
|
||||
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
|
|
@ -86,9 +105,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
request_body = self.__provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
|
||||
response = self.__client.invoke_model_with_response_stream(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
|
||||
collected_content = []
|
||||
for event in response["body"]:
|
||||
|
|
@ -119,3 +136,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self._chat_completion_stream(messages)
|
||||
|
||||
def _get_response_body(self, response) -> dict:
|
||||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ class BaseBedrockProvider(ABC):
|
|||
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response) -> str:
|
||||
response_body = self._get_response_body_json(response)
|
||||
def get_choice_text(self, response_body: dict) -> str:
|
||||
completions = self._get_completion_from_dict(response_body)
|
||||
return completions
|
||||
|
||||
|
|
@ -25,10 +24,6 @@ class BaseBedrockProvider(ABC):
|
|||
completions = self._get_completion_from_dict(rsp_dict)
|
||||
return completions
|
||||
|
||||
def _get_response_body_json(self, response):
|
||||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from typing import Literal
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3, messages_to_prompt_claude
|
||||
from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3
|
||||
|
||||
|
||||
class MistralProvider(BaseBedrockProvider):
|
||||
|
|
@ -17,9 +17,6 @@ class MistralProvider(BaseBedrockProvider):
|
|||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
return messages_to_prompt_claude(messages)
|
||||
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
body = json.dumps(
|
||||
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ SUPPORT_STREAM_MODELS = {
|
|||
# TODO:use a more general function for constructing chat templates.
|
||||
|
||||
|
||||
def messages_to_prompt_llama2(messages: list[dict]):
|
||||
def messages_to_prompt_llama2(messages: list[dict]) -> str:
|
||||
BOS, EOS = "<s>", "</s>"
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
|
@ -56,7 +56,7 @@ def messages_to_prompt_llama2(messages: list[dict]):
|
|||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_llama3(messages: list[dict]):
|
||||
def messages_to_prompt_llama3(messages: list[dict]) -> str:
|
||||
BOS, EOS = "<|begin_of_text|>", "<|eot_id|>"
|
||||
GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
|
||||
|
||||
|
|
@ -72,7 +72,7 @@ def messages_to_prompt_llama3(messages: list[dict]):
|
|||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_claude(messages: list[dict]):
|
||||
def messages_to_prompt_claude2(messages: list[dict]) -> str:
|
||||
GENERAL_TEMPLATE = "\n\n{role}: {content}"
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
|
|
|
|||
|
|
@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
|
|||
mock_llm_config_anthropic = LLMConfig(
|
||||
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
|
||||
)
|
||||
|
||||
mock_llm_config_bedrock = LLMConfig(
|
||||
api_type="amazon_bedrock",
|
||||
model="gpt-100",
|
||||
region_name="somewhere",
|
||||
access_key="123abc",
|
||||
secret_key="123abc",
|
||||
max_token=10000,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -65,17 +65,20 @@ def get_openai_chat_completion(name: str) -> ChatCompletion:
|
|||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
usage=CompletionUsage(completion_tokens=110,
|
||||
prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = CompletionUsage(completion_tokens=110,
|
||||
prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
|
|
@ -84,7 +87,8 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) ->
|
|||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
delta=ChoiceDelta(role="assistant",
|
||||
content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
|
|
@ -133,7 +137,8 @@ def get_dashscope_response(name: str) -> GenerationResponse:
|
|||
],
|
||||
}
|
||||
),
|
||||
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
usage=GenerationUsage(
|
||||
**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -155,7 +160,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
|
|||
),
|
||||
ContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
|
||||
delta=TextDelta(text=resp_cont_tmpl.format(
|
||||
name=name), type="text_delta"),
|
||||
type="content_block_delta",
|
||||
),
|
||||
]
|
||||
|
|
@ -165,7 +171,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
|
|||
model=name,
|
||||
role="assistant",
|
||||
type="message",
|
||||
content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
|
||||
content=[ContentBlock(
|
||||
text=resp_cont_tmpl.format(name=name), type="text")],
|
||||
usage=AnthropicUsage(input_tokens=10, output_tokens=10),
|
||||
)
|
||||
|
||||
|
|
@ -183,3 +190,52 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
|
|||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
|
||||
# For Amazon Bedrock
|
||||
BEDROCK_PROVIDER_REQUEST_BODY = {
|
||||
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
|
||||
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
|
||||
"ai21": {
|
||||
"prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0,
|
||||
"stopSequences": [], "countPenalty": {"scale": 0.0},
|
||||
"presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0}
|
||||
},
|
||||
"cohere": {
|
||||
"prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [],
|
||||
"return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE"
|
||||
},
|
||||
"anthropic": {
|
||||
"anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "",
|
||||
"messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": []
|
||||
},
|
||||
"amazon": {
|
||||
"inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []}
|
||||
}
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY = {
|
||||
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
|
||||
|
||||
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
|
||||
|
||||
"ai21": {
|
||||
"id": "", "prompt": {"text": "Hello World", "tokens": []},
|
||||
"completions": [{"data": {"text": "Hello World", "tokens": []},
|
||||
"finishReason": {"reason": "length", "length": 2}}]
|
||||
},
|
||||
"cohere": {
|
||||
"generations": [
|
||||
{"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0,
|
||||
"token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": ""
|
||||
},
|
||||
|
||||
"anthropic": {
|
||||
"id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}],
|
||||
"stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
|
||||
},
|
||||
|
||||
"amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]}
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_STREAM_RESPONSE = {}
|
||||
|
|
|
|||
91
tests/metagpt/provider/test_amazon_bedrock_api.py
Normal file
91
tests/metagpt/provider/test_amazon_bedrock_api.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
import pytest
|
||||
import json
|
||||
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
BEDROCK_PROVIDER_REQUEST_BODY,
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY,
|
||||
BEDROCK_PROVIDER_STREAM_RESPONSE)
|
||||
from botocore.response import StreamingBody
|
||||
|
||||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
|
||||
|
||||
def mock_bedrock_provider_response(self, *args, **kwargs) -> dict:
|
||||
provider = self.config.model.split(".")[0]
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody:
|
||||
provider = self.config.model.split(".")[0]
|
||||
response_json = BEDROCK_PROVIDER_STREAM_RESPONSE[provider]
|
||||
return
|
||||
|
||||
|
||||
def get_bedrock_request_body(model_id) -> dict:
|
||||
provider = model_id.split(".")[0]
|
||||
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
|
||||
|
||||
|
||||
def is_subset(subset, superset):
|
||||
"""Ensure all fields in request body are allowed.
|
||||
|
||||
```python
|
||||
subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}}
|
||||
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
|
||||
is_subset(subset, superset)
|
||||
```
|
||||
>>>False
|
||||
"""
|
||||
for key, value in subset.items():
|
||||
if key not in superset:
|
||||
return False
|
||||
if isinstance(value, dict):
|
||||
if not isinstance(superset[key], dict):
|
||||
return False
|
||||
if not is_subset(value, superset[key]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@ pytest.fixture(scope="class", params=models)
|
||||
def bedrock_api(request) -> AmazonBedrockLLM:
|
||||
model_id = request.param
|
||||
mock_llm_config_bedrock.model = model_id
|
||||
api = AmazonBedrockLLM(mock_llm_config_bedrock)
|
||||
return api
|
||||
|
||||
|
||||
class TestAPI:
|
||||
def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
|
||||
provider = bedrock_api._get_provider()
|
||||
assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
|
||||
bedrock_api.config.model)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
|
||||
provider = bedrock_api._get_provider()
|
||||
request_body = json.loads(provider.get_request_body(
|
||||
messages, **bedrock_api._generate_kwargs))
|
||||
print(get_bedrock_request_body(
|
||||
bedrock_api.config.model))
|
||||
print(request_body)
|
||||
|
||||
assert is_subset(request_body, get_bedrock_request_body(
|
||||
bedrock_api.config.model))
|
||||
|
||||
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response)
|
||||
assert bedrock_api.completion(messages) == "Hello World"
|
||||
|
||||
# def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
|
||||
# mocker.patch(
|
||||
# "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response)
|
||||
# assert bedrock_api._chat_completion_stream(messages) == "Hello World"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(get_bedrock_request_body("amazon.titan-tg1-large"))
|
||||
Loading…
Add table
Add a link
Reference in a new issue