mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
resolve problem and add cost manager
This commit is contained in:
parent
f14a1f63ef
commit
0006b62901
4 changed files with 112 additions and 63 deletions
|
|
@ -13,25 +13,34 @@ NOT_SUUPORT_STREAM_MODELS = {
|
|||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large": 8000,
|
||||
"amazon.titan-text-express-v1": 8000,
|
||||
"amazon.titan-text-express-v1:0:8k": 8000,
|
||||
"amazon.titan-text-lite-v1:0:4k": 4000,
|
||||
"amazon.titan-text-lite-v1": 4000,
|
||||
"anthropic.claude-instant-v1": 100000,
|
||||
"anthropic.claude-instant-v1:2:100k": 100000,
|
||||
"anthropic.claude-v1": 100000,
|
||||
"anthropic.claude-v2": 100000,
|
||||
"anthropic.claude-v2:1": 200000,
|
||||
"anthropic.claude-v2:0:18k": 18000,
|
||||
"anthropic.claude-v2:1:200k": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 200000,
|
||||
"cohere.command-text-v14": 4000,
|
||||
"cohere.command-text-v14:7:4k": 4000,
|
||||
"cohere.command-light-text-v14": 4000,
|
||||
"cohere.command-light-text-v14:7:4k": 4000,
|
||||
"meta.llama2-70b-v1": 4000,
|
||||
"meta.llama2-13b-chat-v1:0:4k": 4000,
|
||||
"meta.llama2-13b-chat-v1": 2000,
|
||||
"meta.llama2-70b-v1": 4000,
|
||||
"meta.llama2-70b-v1:0:4k": 4000,
|
||||
"meta.llama2-70b-chat-v1": 4000,
|
||||
"meta.llama2-70b-chat-v1:0:4k": 4000,
|
||||
"meta.llama3-8b-instruct-v1:0": 2000,
|
||||
"meta.llama3-70b-instruct-v1:0": 2000,
|
||||
"mistral.mistral-7b-instruct-v0:2": 32000,
|
||||
|
|
@ -43,14 +52,14 @@ SUPPORT_STREAM_MODELS = {
|
|||
|
||||
|
||||
def messages_to_prompt_llama2(messages: list[dict]) -> str:
|
||||
BOS, EOS = "<s>", "</s>"
|
||||
BOS = ("<s>",)
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
||||
prompt = f"{BOS}"
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
if role == "system":
|
||||
prompt += f"{B_SYS} {content} {E_SYS}"
|
||||
elif role == "user":
|
||||
|
|
@ -58,25 +67,24 @@ def messages_to_prompt_llama2(messages: list[dict]) -> str:
|
|||
elif role == "assistant":
|
||||
prompt += f"{content}"
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown role name {role} when formatting messages")
|
||||
logger.warning(f"Unknown role name {role} when formatting messages")
|
||||
prompt += f"{content}"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_llama3(messages: list[dict]) -> str:
|
||||
BOS, EOS = "<|begin_of_text|>", "<|eot_id|>"
|
||||
BOS = "<|begin_of_text|>"
|
||||
GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
|
||||
|
||||
prompt = f"{BOS}"
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
|
||||
|
||||
if role != "assistant":
|
||||
prompt += f"<|start_header_id|>assistant<|end_header_id|>"
|
||||
prompt += "<|start_header_id|>assistant<|end_header_id|>"
|
||||
|
||||
return prompt
|
||||
|
||||
|
|
@ -85,15 +93,20 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str:
|
|||
GENERAL_TEMPLATE = "\n\n{role}: {content}"
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
|
||||
|
||||
if role != "assistant":
|
||||
prompt += f"\n\nAssistant:"
|
||||
prompt += "\n\nAssistant:"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_max_tokens(model_id) -> int:
|
||||
return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
def get_max_tokens(model_id: str) -> int:
|
||||
try:
|
||||
max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
except KeyError:
|
||||
logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048")
|
||||
max_tokens = 2048
|
||||
return max_tokens
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
|
||||
|
||||
|
||||
@register_provider([LLMType.BEDROCK])
|
||||
|
|
@ -19,6 +21,7 @@ class BedrockLLM(BaseLLM):
|
|||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
logger.warning("Amazon bedrock doesn't support asynchronous now")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
|
|
@ -62,14 +65,14 @@ class BedrockLLM(BaseLLM):
|
|||
def invoke_model(self, request_body: str) -> dict:
|
||||
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage)
|
||||
self._update_costs(usage, self.config.model)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage)
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response
|
||||
|
||||
@property
|
||||
|
|
@ -82,16 +85,29 @@ class BedrockLLM(BaseLLM):
|
|||
|
||||
return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature}
|
||||
|
||||
def completion(self, messages: list[dict]) -> str:
|
||||
# boto3 don't support support asynchronous calls.
|
||||
# for asynchronous version of boto3, check out:
|
||||
# https://aioboto3.readthedocs.io/en/latest/usage.html
|
||||
# However,aioboto3 doesn't support invoke model
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return self.__provider.get_choice_text(rsp)
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
completions = self.__provider.get_choice_text(response_body)
|
||||
return completions
|
||||
return response_body
|
||||
|
||||
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self.acompletion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
return self.completion(messages)
|
||||
rsp = await self.acompletion(messages)
|
||||
full_text = self.get_choice_text(rsp)
|
||||
log_llm_stream(full_text)
|
||||
return full_text
|
||||
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
|
||||
|
||||
|
|
@ -106,20 +122,6 @@ class BedrockLLM(BaseLLM):
|
|||
full_text = ("".join(collected_content)).lstrip()
|
||||
return full_text
|
||||
|
||||
# boto3 don't support support asynchronous calls.
|
||||
# for asynchronous version of boto3, check out:
|
||||
# https://aioboto3.readthedocs.io/en/latest/usage.html
|
||||
# However,aioboto3 doesn't support invoke model
|
||||
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self.completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self._chat_completion_stream(messages)
|
||||
|
||||
def _get_response_body(self, response) -> dict:
|
||||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
|
|
|||
|
|
@ -198,6 +198,53 @@ TOKEN_MAX = {
|
|||
"openai/gpt-4-turbo-preview": 128000,
|
||||
}
|
||||
|
||||
# For Amazon Bedrock US region
|
||||
# See https://aws.amazon.com/cn/bedrock/pricing/
|
||||
|
||||
BEDROCK_TOKEN_COSTS = {
|
||||
"amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004},
|
||||
"amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004},
|
||||
"anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024},
|
||||
"anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024},
|
||||
"anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
"anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075},
|
||||
"cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015},
|
||||
"cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015},
|
||||
"cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001},
|
||||
"meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001},
|
||||
"meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006},
|
||||
"meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035},
|
||||
"mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002},
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007},
|
||||
"mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024},
|
||||
"ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188},
|
||||
"ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188},
|
||||
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
|
||||
}
|
||||
|
||||
|
||||
def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import pytest
|
|||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
get_max_tokens,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
|
|
@ -17,14 +16,19 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
usage = {
|
||||
"prompt_tokens": 1000000,
|
||||
"completion_tokens": 1000000,
|
||||
}
|
||||
|
||||
|
||||
def mock_bedrock_provider_response(self, *args, **kwargs) -> dict:
|
||||
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
provider = self.config.model.split(".")[0]
|
||||
self._update_costs(usage, self.config.model)
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
||||
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
# use json object to mock EventStream
|
||||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
|
@ -43,6 +47,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
|||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response_body_stream
|
||||
|
||||
|
||||
|
|
@ -82,41 +87,23 @@ def bedrock_api(request) -> BedrockLLM:
|
|||
|
||||
class TestBedrockAPI:
|
||||
def _patch_invoke_model(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response,
|
||||
mock_invoke_model_stream,
|
||||
)
|
||||
|
||||
def test_const_kwargs(self, bedrock_api: BedrockLLM):
|
||||
provider = bedrock_api.provider
|
||||
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
|
||||
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
def test_completion(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
assert bedrock_api.completion(messages) == "Hello World"
|
||||
|
||||
def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
assert bedrock_api._chat_completion_stream(messages) == "Hello World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
self._patch_invoke_model(mocker)
|
||||
assert await bedrock_api._achat_completion_stream(messages) == "Hello World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
assert await bedrock_api.acompletion(messages) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=True) == "Hello World"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue