mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
add pre-commit
This commit is contained in:
parent
3f108abd06
commit
f14a1f63ef
7 changed files with 132 additions and 97 deletions
|
|
@ -100,4 +100,3 @@ class LLMConfig(YamlModel):
|
|||
@classmethod
|
||||
def check_timeout(cls, v):
|
||||
return v or LLM_API_TIMEOUT
|
||||
|
||||
|
|
|
|||
|
|
@ -31,5 +31,5 @@ __all__ = [
|
|||
"QianFanLLM",
|
||||
"DashScopeLLM",
|
||||
"AnthropicLLM",
|
||||
"BedrockLLM"
|
||||
"BedrockLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ class BaseBedrockProvider(ABC):
|
|||
...
|
||||
|
||||
def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str:
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), **const_kwargs})
|
||||
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response_body: dict) -> str:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import json
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
messages_to_prompt_llama2,
|
||||
messages_to_prompt_llama3,
|
||||
)
|
||||
|
||||
|
||||
class MistralProvider(BaseBedrockProvider):
|
||||
|
|
@ -18,8 +22,7 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
|
|
@ -43,7 +46,8 @@ class CohereProvider(BaseBedrockProvider):
|
|||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs})
|
||||
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
|
||||
)
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
|
|
@ -85,10 +89,7 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
max_tokens_field_name = "maxTokenCount"
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps({
|
||||
"inputText": self.messages_to_prompt(messages),
|
||||
"textGenerationConfig": generate_kwargs
|
||||
})
|
||||
body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs})
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
|
|
@ -106,7 +107,7 @@ PROVIDERS = {
|
|||
"ai21": Ai21Provider,
|
||||
"cohere": CohereProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"amazon": AmazonProvider
|
||||
"amazon": AmazonProvider,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
from typing import Literal
|
||||
import json
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
from typing import Literal
|
||||
|
||||
import boto3
|
||||
from botocore.eventstream import EventStream
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider([LLMType.BEDROCK])
|
||||
class BedrockLLM(BaseLLM):
|
||||
|
|
@ -17,8 +19,7 @@ class BedrockLLM(BaseLLM):
|
|||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
logger.warning(
|
||||
"Amazon bedrock doesn't support asynchronous now")
|
||||
logger.warning("Amazon bedrock doesn't support asynchronous now")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
"""initialize boto3 client"""
|
||||
|
|
@ -26,7 +27,7 @@ class BedrockLLM(BaseLLM):
|
|||
self.__credentital_kwargs = {
|
||||
"aws_secret_access_key": self.config.secret_key,
|
||||
"aws_access_key_id": self.config.access_key,
|
||||
"region_name": self.config.region_name
|
||||
"region_name": self.config.region_name,
|
||||
}
|
||||
session = boto3.Session(**self.__credentital_kwargs)
|
||||
client = session.client(service_name)
|
||||
|
|
@ -52,22 +53,21 @@ class BedrockLLM(BaseLLM):
|
|||
client = self.__init_client("bedrock")
|
||||
# only output text-generation models
|
||||
response = client.list_foundation_models(byOutputModality="TEXT")
|
||||
summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
|
||||
for summary in response["modelSummaries"]]
|
||||
logger.info("\n"+"\n".join(summaries))
|
||||
summaries = [
|
||||
f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
|
||||
for summary in response["modelSummaries"]
|
||||
]
|
||||
logger.info("\n" + "\n".join(summaries))
|
||||
|
||||
def invoke_model(self, request_body: str) -> dict:
|
||||
response = self.__client.invoke_model(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
response = self.__client.invoke_model_with_response_stream(
|
||||
modelId=self.config.model, body=request_body)
|
||||
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage)
|
||||
return response
|
||||
|
|
@ -80,26 +80,20 @@ class BedrockLLM(BaseLLM):
|
|||
else:
|
||||
max_tokens = self.config.max_token
|
||||
|
||||
return {
|
||||
self.__provider.max_tokens_field_name: max_tokens,
|
||||
"temperature": self.config.temperature
|
||||
}
|
||||
return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature}
|
||||
|
||||
def completion(self, messages: list[dict]) -> str:
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, self._const_kwargs)
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
completions = self.__provider.get_choice_text(response_body)
|
||||
return completions
|
||||
|
||||
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(
|
||||
f"model {self.config.model} doesn't support streaming output!")
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
return self.completion(messages)
|
||||
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, self._const_kwargs, stream=True)
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
|
||||
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = []
|
||||
|
|
@ -134,8 +128,10 @@ class BedrockLLM(BaseLLM):
|
|||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
},
|
||||
usage = (
|
||||
{
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
},
|
||||
)
|
||||
return usage
|
||||
|
|
|
|||
|
|
@ -65,20 +65,17 @@ def get_openai_chat_completion(name: str) -> ChatCompletion:
|
|||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110,
|
||||
prompt_tokens=92, total_tokens=202),
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110,
|
||||
prompt_tokens=92, total_tokens=202)
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
|
|
@ -87,8 +84,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) ->
|
|||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant",
|
||||
content=resp_cont_tmpl.format(name=name)),
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
|
|
@ -137,8 +133,7 @@ def get_dashscope_response(name: str) -> GenerationResponse:
|
|||
],
|
||||
}
|
||||
),
|
||||
usage=GenerationUsage(
|
||||
**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -170,8 +165,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
|
|||
model=name,
|
||||
role="assistant",
|
||||
type="message",
|
||||
content=[ContentBlock(
|
||||
text=resp_cont_tmpl.format(name=name), type="text")],
|
||||
content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
|
||||
usage=AnthropicUsage(input_tokens=10, output_tokens=10),
|
||||
)
|
||||
|
||||
|
|
@ -198,43 +192,81 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
|
|||
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
|
||||
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
|
||||
"ai21": {
|
||||
"prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0,
|
||||
"stopSequences": [], "countPenalty": {"scale": 0.0},
|
||||
"presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0}
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"topP": 0.0,
|
||||
"maxTokens": 0,
|
||||
"stopSequences": [],
|
||||
"countPenalty": {"scale": 0.0},
|
||||
"presencePenalty": {"scale": 0.0},
|
||||
"frequencyPenalty": {"scale": 0.0},
|
||||
},
|
||||
"cohere": {
|
||||
"prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [],
|
||||
"return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE"
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"p": 0.0,
|
||||
"k": 0.0,
|
||||
"max_tokens": 0,
|
||||
"stop_sequences": [],
|
||||
"return_likelihoods": "NONE",
|
||||
"stream": False,
|
||||
"num_generations": 0,
|
||||
"logit_bias": {},
|
||||
"truncate": "NONE",
|
||||
},
|
||||
"anthropic": {
|
||||
"anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "",
|
||||
"messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": []
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"max_tokens": 0,
|
||||
"system": "",
|
||||
"messages": [{"role": "", "content": ""}],
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.0,
|
||||
"top_k": 0,
|
||||
"stop_sequences": [],
|
||||
},
|
||||
"amazon": {
|
||||
"inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []}
|
||||
}
|
||||
"inputText": "",
|
||||
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
|
||||
},
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY = {
|
||||
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
|
||||
|
||||
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
|
||||
|
||||
"ai21": {
|
||||
"id": "", "prompt": {"text": "Hello World", "tokens": []},
|
||||
"completions": [{"data": {"text": "Hello World", "tokens": []},
|
||||
"finishReason": {"reason": "length", "length": 2}}]
|
||||
"id": "",
|
||||
"prompt": {"text": "Hello World", "tokens": []},
|
||||
"completions": [
|
||||
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
|
||||
],
|
||||
},
|
||||
"cohere": {
|
||||
"generations": [
|
||||
{"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0,
|
||||
"token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": ""
|
||||
{
|
||||
"finish_reason": "",
|
||||
"id": "",
|
||||
"text": "Hello World",
|
||||
"likelihood": 0.0,
|
||||
"token_likelihoods": [{"token": 0.0}],
|
||||
"is_finished": True,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"id": "",
|
||||
"prompt": "",
|
||||
},
|
||||
|
||||
"anthropic": {
|
||||
"id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}],
|
||||
"stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
|
||||
"id": "",
|
||||
"model": "",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello World"}],
|
||||
"stop_reason": "",
|
||||
"stop_sequence": "",
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0},
|
||||
},
|
||||
"amazon": {
|
||||
"inputTextTokenCount": 0,
|
||||
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
|
||||
},
|
||||
|
||||
"amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,21 @@
|
|||
import pytest
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
get_max_tokens,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
|
||||
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
BEDROCK_PROVIDER_REQUEST_BODY,
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY,
|
||||
)
|
||||
|
||||
# all available model from bedrock
|
||||
models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS)
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
|
||||
|
||||
|
|
@ -19,22 +28,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
|||
# use json object to mock EventStream
|
||||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
||||
provider = self.config.model.split(".")[0]
|
||||
|
||||
if provider == "amazon":
|
||||
response_body_bytes = dict2bytes({"outputText": "Hello World"})
|
||||
elif provider == "anthropic":
|
||||
response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0,
|
||||
"delta": {"type": "text_delta", "text": "Hello World"}})
|
||||
response_body_bytes = dict2bytes(
|
||||
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
|
||||
)
|
||||
elif provider == "cohere":
|
||||
response_body_bytes = dict2bytes(
|
||||
{"is_finished": False, "text": "Hello World"})
|
||||
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
|
||||
else:
|
||||
response_body_bytes = dict2bytes(
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {
|
||||
"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
return response_body_stream
|
||||
|
||||
|
||||
|
|
@ -77,19 +85,19 @@ class TestBedrockAPI:
|
|||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response)
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response,
|
||||
)
|
||||
|
||||
def test_const_kwargs(self, bedrock_api: BedrockLLM):
|
||||
provider = bedrock_api.provider
|
||||
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
|
||||
bedrock_api.config.model)
|
||||
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(
|
||||
messages, bedrock_api._const_kwargs))
|
||||
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
|
||||
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue