diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 123495da5..b6d12b8d9 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -1,15 +1,14 @@ import json -from typing import Coroutine, Literal +from typing import Literal from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.provider.llm_provider_registry import register_provider from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger -from botocore.config import Config -import boto3 - from metagpt.provider.bedrock.bedrock_provider import get_provider +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS +import boto3 @register_provider([LLMType.AMAZON_BEDROCK]) @@ -40,8 +39,9 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self): + # for now only use temperature due to the difference of request body return { - "temperature": self.config.get("temperature", 0.3), + "temperature": self.config.get("temperature", 0.1), } def completion(self, messages: list[dict]): @@ -51,10 +51,14 @@ class AmazonBedrockLLM(BaseLLM): modelId=self.config.model, body=request_body ) completions = self.provider.get_choice_text(response) - log_llm_stream(completions) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + if self.config.model in NOT_SUUPORT_STREAM_MODELS: + logger.warning( + f"model {self.config.model} doesn't support streaming output!") + return self.completion(messages) + request_body = self.provider.get_request_body( messages, **self._generate_kwargs) response = self.__client.invoke_model_with_response_stream( @@ -90,5 +94,5 @@ if __name__ == '__main__': {"role": "assistant", "content": "hello,my friend"}, {"role": "user", "content": "What is your name?"}] llm = AmazonBedrockLLM(my_config) - llm.completion(messages) - llm._chat_completion_stream(messages) + print(llm.completion(messages)) + print(llm._chat_completion_stream(messages)) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 2a17c335b..4a96192d9 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -1,7 +1,8 @@ import json +from abc import ABC, abstractmethod -class BaseBedrockProvider(object): +class BaseBedrockProvider(ABC): # to handle different generation kwargs def get_request_body(self, messages, **generate_kwargs): body = json.dumps( @@ -10,17 +11,22 @@ class BaseBedrockProvider(object): def get_choice_text(self, response) -> str: response_body = self._get_response_body_json(response) - completions = response_body["outputs"][0]['text'] + completions = self._get_completion_from_dict(response_body) return completions def get_choice_text_from_stream(self, event): - completions = json.loads(event["chunk"]["bytes"])["outputs"][0]["text"] + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = self._get_completion_from_dict(rsp_dict) return completions def _get_response_body_json(self, response): response_body = json.loads(response["body"].read()) return response_body + @abstractmethod + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + ... + def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index bbf3de223..d0fe42725 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -5,96 +5,56 @@ from metagpt.provider.bedrock.utils import messages_to_prompt_llama class MistralProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html + def messages_to_prompt(self, messages: list[dict]): return messages_to_prompt_llama(messages) + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["outputs"][0]["text"] + class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + def get_request_body(self, messages, **generate_kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body - def get_choice_text(self, response) -> str: - response_body = self._get_response_body_json(response) - completions = response_body["content"][0]['text'] - return completions - - def get_choice_text_from_stream(self, event): - completions = json.loads(event["chunk"]["bytes"])["content"][0]["text"] - return completions + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["content"][0]["text"] class CohereProvider(BaseBedrockProvider): - pass + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["generations"][0]["text"] class MetaProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + def messages_to_prompt(self, messages: list[dict]): return messages_to_prompt_llama(messages) - def get_choice_text(self, response) -> str: - response_body = self._get_response_body_json(response) - completions = response_body['generation'] - return completions - - def get_choice_text_from_stream(self, event): - completions = json.loads(event["chunk"]["bytes"])["generation"] - return completions + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["generation"] class Ai21Provider(BaseBedrockProvider): - pass + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict['completions'][0]["data"]["text"] PROVIDERS = { "mistral": MistralProvider(), "meta": MetaProvider(), -} - -NOT_SUUPORT_STREAM_MODELS = { - "ai21.j2-grande-instruct", - "ai21.j2-jumbo-instruct", - "ai21.j2-mid", - "ai21.j2-mid-v1", - "ai21.j2-ultra", - "ai21.j2-ultra-v1", -} - -SUPPORT_STREAM_MODELS = { - "amazon.titan-tg1-large", - "amazon.titan-text-lite-v1:0:4k", - "amazon.titan-text-lite-v1", - "amazon.titan-text-express-v1:0:8k", - "amazon.titan-text-express-v1", - "anthropic.claude-instant-v1:2:100k", - "anthropic.claude-instant-v1", - "anthropic.claude-v2:0:18k", - "anthropic.claude-v2:0:100k", - "anthropic.claude-v2:1:18k", - "anthropic.claude-v2:1:200k", - "anthropic.claude-v2:1", - "anthropic.claude-v2:2:18k", - "anthropic.claude-v2:2:200k", - "anthropic.claude-v2:2", - "anthropic.claude-v2", - "anthropic.claude-3-sonnet-20240229-v1:0:28k", - "anthropic.claude-3-sonnet-20240229-v1:0:200k", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0:48k", - "anthropic.claude-3-haiku-20240307-v1:0:200k", - "anthropic.claude-3-haiku-20240307-v1:0", - "cohere.command-text-v14:7:4k", - "cohere.command-text-v14", - "cohere.command-light-text-v14:7:4k", - "cohere.command-light-text-v14", - "meta.llama2-70b-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", + "ai21": Ai21Provider(), + "cohere": CohereProvider(), + "anthropic": AnthropicProvider(), } diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 7352a15a0..57b83681c 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -23,3 +23,47 @@ def messages_to_prompt_llama(messages: list[dict]): return prompt + +NOT_SUUPORT_STREAM_MODELS = { + "ai21.j2-grande-instruct", + "ai21.j2-jumbo-instruct", + "ai21.j2-mid", + "ai21.j2-mid-v1", + "ai21.j2-ultra", + "ai21.j2-ultra-v1", +} + +SUPPORT_STREAM_MODELS = { + "amazon.titan-tg1-large", + "amazon.titan-text-lite-v1:0:4k", + "amazon.titan-text-lite-v1", + "amazon.titan-text-express-v1:0:8k", + "amazon.titan-text-express-v1", + "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-instant-v1", + "anthropic.claude-v2:0:18k", + "anthropic.claude-v2:0:100k", + "anthropic.claude-v2:1:18k", + "anthropic.claude-v2:1:200k", + "anthropic.claude-v2:1", + "anthropic.claude-v2:2:18k", + "anthropic.claude-v2:2:200k", + "anthropic.claude-v2:2", + "anthropic.claude-v2", + "anthropic.claude-3-sonnet-20240229-v1:0:28k", + "anthropic.claude-3-sonnet-20240229-v1:0:200k", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0:48k", + "anthropic.claude-3-haiku-20240307-v1:0:200k", + "anthropic.claude-3-haiku-20240307-v1:0", + "cohere.command-text-v14:7:4k", + "cohere.command-text-v14", + "cohere.command-light-text-v14:7:4k", + "cohere.command-light-text-v14", + "meta.llama2-70b-v1", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", +}