mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
add titan
This commit is contained in:
parent
a6058ca629
commit
f45a379183
3 changed files with 24 additions and 15 deletions
|
|
@ -7,7 +7,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS
|
||||
import boto3
|
||||
|
||||
|
||||
|
|
@ -86,13 +86,3 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
return self._chat_completion_stream(messages)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from .config import my_config
|
||||
messages = [
|
||||
{"role": "system", "content": "your name is Bob"},
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hello,my friend"},
|
||||
{"role": "user", "content": "What is your name?"}]
|
||||
llm = AmazonBedrockLLM(my_config)
|
||||
print(llm.completion(messages))
|
||||
print(llm._chat_completion_stream(messages))
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@ from abc import ABC, abstractmethod
|
|||
|
||||
class BaseBedrockProvider(ABC):
|
||||
# to handle different generation kwargs
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
|
||||
|
|
@ -23,10 +28,6 @@ class BaseBedrockProvider(ABC):
|
|||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
|
|
|||
|
|
@ -49,12 +49,30 @@ class Ai21Provider(BaseBedrockProvider):
|
|||
return rsp_dict['completions'][0]["data"]["text"]
|
||||
|
||||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
body = json.dumps({
|
||||
"inputText": self.messages_to_prompt(messages),
|
||||
"textGenerationConfig": generate_kwargs
|
||||
})
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict['results'][0]['outputText'].strip()
|
||||
|
||||
def get_choice_text_from_stream(self, event):
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict["outputText"]
|
||||
return completions
|
||||
|
||||
|
||||
PROVIDERS = {
|
||||
"mistral": MistralProvider(),
|
||||
"meta": MetaProvider(),
|
||||
"ai21": Ai21Provider(),
|
||||
"cohere": CohereProvider(),
|
||||
"anthropic": AnthropicProvider(),
|
||||
"amazon": AmazonProvider()
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue