mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
implement all model
This commit is contained in:
parent
187e9ef698
commit
6355c5c0ed
4 changed files with 87 additions and 73 deletions
|
|
@ -1,15 +1,14 @@
|
|||
|
||||
import json
|
||||
from typing import Coroutine, Literal
|
||||
from typing import Literal
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from botocore.config import Config
|
||||
import boto3
|
||||
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS
|
||||
import boto3
|
||||
|
||||
|
||||
@register_provider([LLMType.AMAZON_BEDROCK])
|
||||
|
|
@ -40,8 +39,9 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
@property
|
||||
def _generate_kwargs(self):
|
||||
# for now only use temperature due to the difference of request body
|
||||
return {
|
||||
"temperature": self.config.get("temperature", 0.3),
|
||||
"temperature": self.config.get("temperature", 0.1),
|
||||
}
|
||||
|
||||
def completion(self, messages: list[dict]):
|
||||
|
|
@ -51,10 +51,14 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
completions = self.provider.get_choice_text(response)
|
||||
log_llm_stream(completions)
|
||||
return completions
|
||||
|
||||
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(
|
||||
f"model {self.config.model} doesn't support streaming output!")
|
||||
return self.completion(messages)
|
||||
|
||||
request_body = self.provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
response = self.__client.invoke_model_with_response_stream(
|
||||
|
|
@ -90,5 +94,5 @@ if __name__ == '__main__':
|
|||
{"role": "assistant", "content": "hello,my friend"},
|
||||
{"role": "user", "content": "What is your name?"}]
|
||||
llm = AmazonBedrockLLM(my_config)
|
||||
llm.completion(messages)
|
||||
llm._chat_completion_stream(messages)
|
||||
print(llm.completion(messages))
|
||||
print(llm._chat_completion_stream(messages))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseBedrockProvider(object):
|
||||
class BaseBedrockProvider(ABC):
|
||||
# to handle different generation kwargs
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
body = json.dumps(
|
||||
|
|
@ -10,17 +11,22 @@ class BaseBedrockProvider(object):
|
|||
|
||||
def get_choice_text(self, response) -> str:
|
||||
response_body = self._get_response_body_json(response)
|
||||
completions = response_body["outputs"][0]['text']
|
||||
completions = self._get_completion_from_dict(response_body)
|
||||
return completions
|
||||
|
||||
def get_choice_text_from_stream(self, event):
|
||||
completions = json.loads(event["chunk"]["bytes"])["outputs"][0]["text"]
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = self._get_completion_from_dict(rsp_dict)
|
||||
return completions
|
||||
|
||||
def _get_response_body_json(self, response):
|
||||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
|
|
|||
|
|
@ -5,96 +5,56 @@ from metagpt.provider.bedrock.utils import messages_to_prompt_llama
|
|||
|
||||
class MistralProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
return messages_to_prompt_llama(messages)
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["outputs"][0]["text"]
|
||||
|
||||
|
||||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
body = json.dumps(
|
||||
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response) -> str:
|
||||
response_body = self._get_response_body_json(response)
|
||||
completions = response_body["content"][0]['text']
|
||||
return completions
|
||||
|
||||
def get_choice_text_from_stream(self, event):
|
||||
completions = json.loads(event["chunk"]["bytes"])["content"][0]["text"]
|
||||
return completions
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["content"][0]["text"]
|
||||
|
||||
|
||||
class CohereProvider(BaseBedrockProvider):
|
||||
pass
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generations"][0]["text"]
|
||||
|
||||
|
||||
class MetaProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
return messages_to_prompt_llama(messages)
|
||||
|
||||
def get_choice_text(self, response) -> str:
|
||||
response_body = self._get_response_body_json(response)
|
||||
completions = response_body['generation']
|
||||
return completions
|
||||
|
||||
def get_choice_text_from_stream(self, event):
|
||||
completions = json.loads(event["chunk"]["bytes"])["generation"]
|
||||
return completions
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generation"]
|
||||
|
||||
|
||||
class Ai21Provider(BaseBedrockProvider):
|
||||
pass
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict['completions'][0]["data"]["text"]
|
||||
|
||||
|
||||
PROVIDERS = {
|
||||
"mistral": MistralProvider(),
|
||||
"meta": MetaProvider(),
|
||||
}
|
||||
|
||||
NOT_SUUPORT_STREAM_MODELS = {
|
||||
"ai21.j2-grande-instruct",
|
||||
"ai21.j2-jumbo-instruct",
|
||||
"ai21.j2-mid",
|
||||
"ai21.j2-mid-v1",
|
||||
"ai21.j2-ultra",
|
||||
"ai21.j2-ultra-v1",
|
||||
}
|
||||
|
||||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large",
|
||||
"amazon.titan-text-lite-v1:0:4k",
|
||||
"amazon.titan-text-lite-v1",
|
||||
"amazon.titan-text-express-v1:0:8k",
|
||||
"amazon.titan-text-express-v1",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-instant-v1",
|
||||
"anthropic.claude-v2:0:18k",
|
||||
"anthropic.claude-v2:0:100k",
|
||||
"anthropic.claude-v2:1:18k",
|
||||
"anthropic.claude-v2:1:200k",
|
||||
"anthropic.claude-v2:1",
|
||||
"anthropic.claude-v2:2:18k",
|
||||
"anthropic.claude-v2:2:200k",
|
||||
"anthropic.claude-v2:2",
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"cohere.command-text-v14:7:4k",
|
||||
"cohere.command-text-v14",
|
||||
"cohere.command-light-text-v14:7:4k",
|
||||
"cohere.command-light-text-v14",
|
||||
"meta.llama2-70b-v1",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"meta.llama3-70b-instruct-v1:0",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"ai21": Ai21Provider(),
|
||||
"cohere": CohereProvider(),
|
||||
"anthropic": AnthropicProvider(),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,3 +23,47 @@ def messages_to_prompt_llama(messages: list[dict]):
|
|||
|
||||
return prompt
|
||||
|
||||
|
||||
NOT_SUUPORT_STREAM_MODELS = {
|
||||
"ai21.j2-grande-instruct",
|
||||
"ai21.j2-jumbo-instruct",
|
||||
"ai21.j2-mid",
|
||||
"ai21.j2-mid-v1",
|
||||
"ai21.j2-ultra",
|
||||
"ai21.j2-ultra-v1",
|
||||
}
|
||||
|
||||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large",
|
||||
"amazon.titan-text-lite-v1:0:4k",
|
||||
"amazon.titan-text-lite-v1",
|
||||
"amazon.titan-text-express-v1:0:8k",
|
||||
"amazon.titan-text-express-v1",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-instant-v1",
|
||||
"anthropic.claude-v2:0:18k",
|
||||
"anthropic.claude-v2:0:100k",
|
||||
"anthropic.claude-v2:1:18k",
|
||||
"anthropic.claude-v2:1:200k",
|
||||
"anthropic.claude-v2:1",
|
||||
"anthropic.claude-v2:2:18k",
|
||||
"anthropic.claude-v2:2:200k",
|
||||
"anthropic.claude-v2:2",
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"cohere.command-text-v14:7:4k",
|
||||
"cohere.command-text-v14",
|
||||
"cohere.command-light-text-v14:7:4k",
|
||||
"cohere.command-light-text-v14",
|
||||
"meta.llama2-70b-v1",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"meta.llama3-70b-instruct-v1:0",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue