mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 12:22:39 +02:00
remove opus since unavailable now and fix bug
This commit is contained in:
parent
a05d25757c
commit
98cb452911
5 changed files with 42 additions and 16 deletions
|
|
@ -86,7 +86,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
def completion(self, messages: list[dict]) -> str:
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
messages, self._generate_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
completions = self.__provider.get_choice_text(response_body)
|
||||
return completions
|
||||
|
|
@ -98,7 +98,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
return self.completion(messages)
|
||||
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
messages, self._generate_kwargs, stream=True)
|
||||
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = []
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class BaseBedrockProvider(ABC):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
|
||||
return body
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class MistralProvider(BaseBedrockProvider):
|
|||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
return body
|
||||
|
|
@ -41,6 +41,16 @@ class CohereProvider(BaseBedrockProvider):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generations"][0]["text"]
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict.get("text", "")
|
||||
return completions
|
||||
|
||||
|
||||
class MetaProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
|
|
@ -74,7 +84,7 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
|
||||
max_tokens_field_name = "maxTokenCount"
|
||||
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps({
|
||||
"inputText": self.messages_to_prompt(messages),
|
||||
"textGenerationConfig": generate_kwargs
|
||||
|
|
|
|||
|
|
@ -14,15 +14,24 @@ SUPPORT_STREAM_MODELS = {
|
|||
"amazon.titan-tg1-large": 8000,
|
||||
"amazon.titan-text-express-v1": 8000,
|
||||
"anthropic.claude-instant-v1": 100000,
|
||||
"anthropic.claude-instant-v1:2:100k": 100000,
|
||||
"anthropic.claude-v1": 100000,
|
||||
"anthropic.claude-v2": 100000,
|
||||
"anthropic.claude-v2:1": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 200000,
|
||||
"cohere.command-text-v14": 4096,
|
||||
"cohere.command-light-text-v14": 4096,
|
||||
"meta.llama2-70b-v1": 4096,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
|
||||
"cohere.command-text-v14": 4000,
|
||||
"cohere.command-text-v14:7:4k": 4000,
|
||||
"cohere.command-light-text-v14": 4000,
|
||||
"cohere.command-light-text-v14:7:4k": 4000,
|
||||
"meta.llama2-70b-v1": 4000,
|
||||
"meta.llama2-13b-chat-v1:0:4k": 4000,
|
||||
"meta.llama2-13b-chat-v1": 2000,
|
||||
"meta.llama2-70b-v1:0:4k": 4000,
|
||||
"meta.llama3-8b-instruct-v1:0": 2000,
|
||||
"meta.llama3-70b-instruct-v1:0": 2000,
|
||||
"mistral.mistral-7b-instruct-v0:2": 32000,
|
||||
|
|
|
|||
|
|
@ -20,14 +20,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
|||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
provider = self.config.model.split(".")[0]
|
||||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
# decoded bytes share the same format as non-stream response_body except for titan
|
||||
|
||||
if provider == "amazon":
|
||||
response_body_stream = {
|
||||
"body": [{'chunk': {'bytes': dict2bytes({"outputText": "Hello World"})}}]}
|
||||
response_body_bytes = dict2bytes({"outputText": "Hello World"})
|
||||
elif provider == "anthropic":
|
||||
response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0,
|
||||
"delta": {"type": "text_delta", "text": "Hello World"}})
|
||||
elif provider == "cohere":
|
||||
response_body_bytes = dict2bytes(
|
||||
{"is_finished": False, "text": "Hello World"})
|
||||
else:
|
||||
response_body_stream = {
|
||||
"body": [{'chunk': {'bytes': response_body_bytes}}]}
|
||||
response_body_bytes = dict2bytes(
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {
|
||||
"body": [{'chunk': {'bytes': response_body_bytes}}]}
|
||||
return response_body_stream
|
||||
|
||||
|
||||
|
|
@ -75,7 +82,7 @@ class TestAPI:
|
|||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api._get_provider()
|
||||
request_body = json.loads(provider.get_request_body(
|
||||
messages, **bedrock_api._generate_kwargs))
|
||||
messages, bedrock_api._generate_kwargs))
|
||||
|
||||
assert is_subset(request_body, get_bedrock_request_body(
|
||||
bedrock_api.config.model))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue