add titan

This commit is contained in:
usamimeri_renko 2024-04-25 21:16:22 +08:00
parent a6058ca629
commit f45a379183
3 changed files with 24 additions and 15 deletions

View file

@ -7,7 +7,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.bedrock.bedrock_provider import get_provider
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS
import boto3
@ -86,13 +86,3 @@ class AmazonBedrockLLM(BaseLLM):
return self._chat_completion_stream(messages)
if __name__ == '__main__':
from .config import my_config
messages = [
{"role": "system", "content": "your name is Bob"},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hello,my friend"},
{"role": "user", "content": "What is your name?"}]
llm = AmazonBedrockLLM(my_config)
print(llm.completion(messages))
print(llm._chat_completion_stream(messages))

View file

@ -4,6 +4,11 @@ from abc import ABC, abstractmethod
class BaseBedrockProvider(ABC):
# to handle different generation kwargs
@abstractmethod
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...
def get_request_body(self, messages, **generate_kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
@ -23,10 +28,6 @@ class BaseBedrockProvider(ABC):
response_body = json.loads(response["body"].read())
return response_body
@abstractmethod
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])

View file

@ -49,12 +49,30 @@ class Ai21Provider(BaseBedrockProvider):
return rsp_dict['completions'][0]["data"]["text"]
class AmazonProvider(BaseBedrockProvider):
def get_request_body(self, messages, **generate_kwargs):
body = json.dumps({
"inputText": self.messages_to_prompt(messages),
"textGenerationConfig": generate_kwargs
})
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['results'][0]['outputText'].strip()
def get_choice_text_from_stream(self, event):
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict["outputText"]
return completions
PROVIDERS = {
"mistral": MistralProvider(),
"meta": MetaProvider(),
"ai21": Ai21Provider(),
"cohere": CohereProvider(),
"anthropic": AnthropicProvider(),
"amazon": AmazonProvider()
}