mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
update max tokens and support max_tokens field
This commit is contained in:
parent
6452adf82a
commit
784c6265ee
4 changed files with 53 additions and 44 deletions
|
|
@ -5,7 +5,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
import boto3
|
||||
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
logger.warning("Amazon bedrock doesn't support async now")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
# access key from https://us-east-1.console.aws.amazon.com/iam
|
||||
|
|
@ -38,7 +39,13 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
@property
|
||||
def _generate_kwargs(self) -> dict:
|
||||
# for now only use temperature due to the difference of request body
|
||||
model_max_tokens = get_max_tokens(self.config.model)
|
||||
if self.config.max_token > model_max_tokens:
|
||||
max_tokens = model_max_tokens
|
||||
else:
|
||||
max_tokens = self.config.max_token
|
||||
return {
|
||||
self.__provider.max_tokens_field_name: max_tokens,
|
||||
"temperature": self.config.temperature
|
||||
}
|
||||
|
||||
|
|
@ -59,6 +66,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
|
||||
response = self.__client.invoke_model_with_response_stream(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
|
|
@ -75,7 +83,13 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
# Amazon bedrock doesn't support async now
|
||||
return self._achat_completion(messages)
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream: bool = False,
|
||||
timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self.completion(messages)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|||
|
||||
class BaseBedrockProvider(ABC):
|
||||
# to handle different generation kwargs
|
||||
max_tokens_field_name = "max_tokens"
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class CohereProvider(BaseBedrockProvider):
|
|||
|
||||
class MetaProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
max_tokens_field_name = "max_gen_len"
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
return messages_to_prompt_llama(messages)
|
||||
|
|
@ -44,6 +45,7 @@ class MetaProvider(BaseBedrockProvider):
|
|||
|
||||
class Ai21Provider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
|
||||
max_tokens_field_name = "maxTokens"
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict['completions'][0]["data"]["text"]
|
||||
|
|
@ -51,6 +53,7 @@ class Ai21Provider(BaseBedrockProvider):
|
|||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
|
||||
max_tokens_field_name = "maxTokenCount"
|
||||
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
body = json.dumps({
|
||||
|
|
|
|||
|
|
@ -1,5 +1,35 @@
|
|||
from metagpt.logs import logger
|
||||
|
||||
NOT_SUUPORT_STREAM_MODELS = {
|
||||
"ai21.j2-grande-instruct": 8000,
|
||||
"ai21.j2-jumbo-instruct": 8000,
|
||||
"ai21.j2-mid": 8000,
|
||||
"ai21.j2-mid-v1": 8000,
|
||||
"ai21.j2-ultra": 8000,
|
||||
"ai21.j2-ultra-v1": 8000,
|
||||
}
|
||||
|
||||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large": 8000,
|
||||
"amazon.titan-text-express-v1": 8000,
|
||||
"anthropic.claude-instant-v1": 100000,
|
||||
"anthropic.claude-v1": 100000,
|
||||
"anthropic.claude-v2": 100000,
|
||||
"anthropic.claude-v2:1": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 200000,
|
||||
"cohere.command-text-v14": 4096,
|
||||
"cohere.command-light-text-v14": 4096,
|
||||
"meta.llama2-70b-v1": 4096,
|
||||
"meta.llama3-8b-instruct-v1:0": 2000,
|
||||
"meta.llama3-70b-instruct-v1:0": 2000,
|
||||
"mistral.mistral-7b-instruct-v0:2": 32000,
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
|
||||
"mistral.mistral-large-2402-v1:0": 32000,
|
||||
}
|
||||
|
||||
|
||||
def messages_to_prompt_llama(messages: list[dict]):
|
||||
BOS, EOS = "<s>", "</s>"
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
|
|
@ -23,46 +53,7 @@ def messages_to_prompt_llama(messages: list[dict]):
|
|||
return prompt
|
||||
|
||||
|
||||
NOT_SUUPORT_STREAM_MODELS = {
|
||||
"ai21.j2-grande-instruct",
|
||||
"ai21.j2-jumbo-instruct",
|
||||
"ai21.j2-mid",
|
||||
"ai21.j2-mid-v1",
|
||||
"ai21.j2-ultra",
|
||||
"ai21.j2-ultra-v1",
|
||||
}
|
||||
def get_max_tokens(model_id) -> int:
|
||||
return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
|
||||
|
||||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large",
|
||||
"amazon.titan-text-lite-v1:0:4k",
|
||||
"amazon.titan-text-lite-v1",
|
||||
"amazon.titan-text-express-v1:0:8k",
|
||||
"amazon.titan-text-express-v1",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-instant-v1",
|
||||
"anthropic.claude-v2:0:18k",
|
||||
"anthropic.claude-v2:0:100k",
|
||||
"anthropic.claude-v2:1:18k",
|
||||
"anthropic.claude-v2:1:200k",
|
||||
"anthropic.claude-v2:1",
|
||||
"anthropic.claude-v2:2:18k",
|
||||
"anthropic.claude-v2:2:200k",
|
||||
"anthropic.claude-v2:2",
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"cohere.command-text-v14:7:4k",
|
||||
"cohere.command-text-v14",
|
||||
"cohere.command-light-text-v14:7:4k",
|
||||
"cohere.command-light-text-v14",
|
||||
"meta.llama2-70b-v1",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"meta.llama3-70b-instruct-v1:0",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue