mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
update bedrock
This commit is contained in:
parent
07d4be2df3
commit
8c84de7468
6 changed files with 75 additions and 36 deletions
|
|
@ -100,6 +100,10 @@ class LLMConfig(YamlModel):
|
|||
# For Messages Control
|
||||
use_system_prompt: bool = True
|
||||
|
||||
# reasoning / thinking switch
|
||||
reasoning: bool = False
|
||||
reasoning_tokens: int = 4000 # reasoning budget tokens to generate, usually smaller than max_tokens
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
|
|
|
|||
|
|
@ -43,6 +43,16 @@ class BaseLLM(ABC):
|
|||
model: Optional[str] = None # deprecated
|
||||
pricing_plan: Optional[str] = None
|
||||
|
||||
_reasoning_content: Optional[str] = None # content from reasoning mode
|
||||
|
||||
@property
|
||||
def reasoning_content(self):
|
||||
return self._reasoning_content
|
||||
|
||||
@reasoning_content.setter
|
||||
def reasoning_content(self, value: str):
|
||||
self._reasoning_content = value
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,11 +1,16 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
|
||||
|
||||
class BaseBedrockProvider(ABC):
|
||||
# to handle different generation kwargs
|
||||
max_tokens_field_name = "max_tokens"
|
||||
|
||||
def __init__(self, reasoning: bool = False, reasoning_tokens: int = 4000):
|
||||
self.reasoning = reasoning
|
||||
self.reasoning_tokens = reasoning_tokens
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
|
@ -14,14 +19,14 @@ class BaseBedrockProvider(ABC):
|
|||
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response_body: dict) -> str:
|
||||
def get_choice_text(self, response_body: dict) -> Union[str, dict[str, str]]:
|
||||
completions = self._get_completion_from_dict(response_body)
|
||||
return completions
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = self._get_completion_from_dict(rsp_dict)
|
||||
return completions
|
||||
return False, completions
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Literal, Tuple
|
||||
from typing import Literal, Tuple, Union
|
||||
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
|
|
@ -32,6 +32,10 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
return self.messages_to_prompt(system_messages), user_messages
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
|
||||
if self.reasoning:
|
||||
generate_kwargs["temperature"] = 1 # should be 1
|
||||
generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_tokens}
|
||||
|
||||
system_message, user_messages = self._split_system_user_messages(messages)
|
||||
body = json.dumps(
|
||||
{
|
||||
|
|
@ -43,17 +47,26 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
)
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> dict[str, Tuple[str, str]]:
|
||||
if self.reasoning:
|
||||
return {"reasoning_content": rsp_dict["content"][0]["thinking"], "content": rsp_dict["content"][1]["text"]}
|
||||
return rsp_dict["content"][0]["text"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
|
||||
# https://docs.anthropic.com/claude/reference/messages-streaming
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
if rsp_dict["type"] == "content_block_delta":
|
||||
completions = rsp_dict["delta"]["text"]
|
||||
return completions
|
||||
reasoning = False
|
||||
if rsp_dict["delta"]["type"] == "text_delta":
|
||||
completions = rsp_dict["delta"]["text"]
|
||||
elif rsp_dict["delta"]["type"] == "thinking_delta":
|
||||
completions = rsp_dict["delta"]["thinking"]
|
||||
reasoning = True
|
||||
elif rsp_dict["delta"]["type"] == "signature_delta":
|
||||
completions = ""
|
||||
return reasoning, completions
|
||||
else:
|
||||
return ""
|
||||
return False, ""
|
||||
|
||||
|
||||
class CohereProvider(BaseBedrockProvider):
|
||||
|
|
@ -87,10 +100,10 @@ class CohereProvider(BaseBedrockProvider):
|
|||
body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict.get("text", "")
|
||||
return completions
|
||||
return False, completions
|
||||
|
||||
|
||||
class MetaProvider(BaseBedrockProvider):
|
||||
|
|
@ -133,10 +146,10 @@ class Ai21Provider(BaseBedrockProvider):
|
|||
)
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||
return completions
|
||||
return False, completions
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
if self.model_type == "j2":
|
||||
|
|
@ -159,10 +172,10 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["results"][0]["outputText"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict["outputText"]
|
||||
return completions
|
||||
return False, completions
|
||||
|
||||
|
||||
PROVIDERS = {
|
||||
|
|
@ -175,8 +188,14 @@ PROVIDERS = {
|
|||
}
|
||||
|
||||
|
||||
def get_provider(model_id: str):
|
||||
provider, model_name = model_id.split(".")[0:2] # meta、mistral……
|
||||
def get_provider(model_id: str, reasoning: bool = False, reasoning_tokens: int = 4000):
|
||||
arr = model_id.split(".")
|
||||
if len(arr) == 2:
|
||||
provider, model_name = arr # meta、mistral……
|
||||
elif len(arr) == 3:
|
||||
# some model_ids may contain country like us.xx.xxx
|
||||
_, provider, model_name = arr
|
||||
|
||||
if provider not in PROVIDERS:
|
||||
raise KeyError(f"{provider} is not supported!")
|
||||
if provider == "meta":
|
||||
|
|
@ -188,4 +207,4 @@ def get_provider(model_id: str):
|
|||
elif provider == "cohere":
|
||||
# distinguish between R/R+ and older models
|
||||
return PROVIDERS[provider](model_name)
|
||||
return PROVIDERS[provider]()
|
||||
return PROVIDERS[provider](reasoning=reasoning, reasoning_tokens=reasoning_tokens)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ SUPPORT_STREAM_MODELS = {
|
|||
"anthropic.claude-3-opus-20240229-v1:0": 4096,
|
||||
# Claude 3.5 Sonnet
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
|
||||
# Claude 3.7 Sonnet
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
|
||||
# Command Text
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
"cohere.command-text-v14": 4096,
|
||||
|
|
@ -135,20 +138,6 @@ def messages_to_prompt_llama3(messages: list[dict]) -> str:
|
|||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_claude2(messages: list[dict]) -> str:
|
||||
GENERAL_TEMPLATE = "\n\n{role}: {content}"
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
|
||||
|
||||
if role != "assistant":
|
||||
prompt += "\n\nAssistant:"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_max_tokens(model_id: str) -> int:
|
||||
try:
|
||||
max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ class BedrockLLM(BaseLLM):
|
|||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.__provider = get_provider(
|
||||
self.config.model, reasoning=self.config.reasoning, reasoning_tokens=self.config.reasoning_tokens
|
||||
)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
|
|
@ -102,7 +104,11 @@ class BedrockLLM(BaseLLM):
|
|||
# However,aioboto3 doesn't support invoke model
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return self.__provider.get_choice_text(rsp)
|
||||
rsp = self.__provider.get_choice_text(rsp)
|
||||
if isinstance(rsp, dict):
|
||||
self.reasoning_content = rsp.get("reasoning_content")
|
||||
rsp = rsp.get("content")
|
||||
return rsp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
|
|
@ -133,10 +139,16 @@ class BedrockLLM(BaseLLM):
|
|||
async def _get_stream_response_body(self, stream_response) -> List[str]:
|
||||
def collect_content() -> str:
|
||||
collected_content = []
|
||||
collected_reasoning_content = []
|
||||
for event in stream_response["body"]:
|
||||
chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
collected_content.append(chunk_text)
|
||||
reasoning, chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
if reasoning:
|
||||
collected_reasoning_content.append(chunk_text)
|
||||
else:
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
if collected_reasoning_content:
|
||||
self.reasoning_content = "".join(collected_reasoning_content)
|
||||
return collected_content
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue