From f45a379183fe19eec7fcd2e1288882691bac2a01 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 21:16:22 +0800 Subject: [PATCH] add titan --- metagpt/provider/bedrock/amazon_bedrock_api.py | 12 +----------- metagpt/provider/bedrock/base_provider.py | 9 +++++---- metagpt/provider/bedrock/bedrock_provider.py | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 7d615ec5e..3b2ea0b81 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -7,7 +7,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider -from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS import boto3 @@ -86,13 +86,3 @@ class AmazonBedrockLLM(BaseLLM): return self._chat_completion_stream(messages) -if __name__ == '__main__': - from .config import my_config - messages = [ - {"role": "system", "content": "your name is Bob"}, - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hello,my friend"}, - {"role": "user", "content": "What is your name?"}] - llm = AmazonBedrockLLM(my_config) - print(llm.completion(messages)) - print(llm._chat_completion_stream(messages)) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 4a96192d9..c591549ce 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -4,6 +4,11 @@ from abc import ABC, abstractmethod class BaseBedrockProvider(ABC): # to handle different generation kwargs + + @abstractmethod + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + ... + def get_request_body(self, messages, **generate_kwargs): body = json.dumps( {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) @@ -23,10 +28,6 @@ class BaseBedrockProvider(ABC): response_body = json.loads(response["body"].read()) return response_body - @abstractmethod - def _get_completion_from_dict(self, rsp_dict: dict) -> str: - ... - def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index d0fe42725..e2dba9223 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -49,12 +49,30 @@ class Ai21Provider(BaseBedrockProvider): return rsp_dict['completions'][0]["data"]["text"] +class AmazonProvider(BaseBedrockProvider): + def get_request_body(self, messages, **generate_kwargs): + body = json.dumps({ + "inputText": self.messages_to_prompt(messages), + "textGenerationConfig": generate_kwargs + }) + return body + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict['results'][0]['outputText'].strip() + + def get_choice_text_from_stream(self, event): + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict["outputText"] + return completions + + PROVIDERS = { "mistral": MistralProvider(), "meta": MetaProvider(), "ai21": Ai21Provider(), "cohere": CohereProvider(), "anthropic": AnthropicProvider(), + "amazon": AmazonProvider() }