diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index ef034ca49..5fc5519fe 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -100,6 +100,10 @@ class LLMConfig(YamlModel):
     # For Messages Control
     use_system_prompt: bool = True
 
+    # reasoning / thinking switch
+    reasoning: bool = False
+    reasoning_tokens: int = 4000  # token budget for reasoning output; should be smaller than max_tokens
+
     @field_validator("api_key")
     @classmethod
     def check_llm_key(cls, v):
diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index a95e8dbd3..80e51f8ac 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -43,6 +43,16 @@ class BaseLLM(ABC):
     model: Optional[str] = None  # deprecated
     pricing_plan: Optional[str] = None
 
+    _reasoning_content: Optional[str] = None  # content produced in reasoning mode
+
+    @property
+    def reasoning_content(self):
+        return self._reasoning_content
+
+    @reasoning_content.setter
+    def reasoning_content(self, value: str):
+        self._reasoning_content = value
+
    @abstractmethod
     def __init__(self, config: LLMConfig):
         pass
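Note: a minimal sketch of how the two new config fields and the new BaseLLM property are meant to be used together. The model id, credential fields, and prompt below are illustrative assumptions, not part of this diff:

import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.bedrock_api import BedrockLLM

# Credential fields are placeholders; fill them in per your own Bedrock setup.
config = LLMConfig(
    api_type="bedrock",
    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    region_name="us-east-1",
    access_key="...",
    secret_key="...",
    reasoning=True,         # the new switch
    reasoning_tokens=2000,  # thinking budget; keep below max_tokens
)

llm = BedrockLLM(config)
answer = asyncio.run(llm.aask("Why is the sky blue?"))
print(llm.reasoning_content)  # thinking text captured via the new property
print(answer)                 # final answer text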
@@ -14,14 +19,14 @@ class BaseBedrockProvider(ABC):
         body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
         return body
 
-    def get_choice_text(self, response_body: dict) -> str:
+    def get_choice_text(self, response_body: dict) -> Union[str, dict[str, str]]:
         completions = self._get_completion_from_dict(response_body)
         return completions
 
-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = self._get_completion_from_dict(rsp_dict)
-        return completions
+        return False, completions
 
     def messages_to_prompt(self, messages: list[dict]) -> str:
         """[{"role": "user", "content": msg}] to user: etc."""
diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py
index c5e3b7bd2..c8bc4a5e3 100644
--- a/metagpt/provider/bedrock/bedrock_provider.py
+++ b/metagpt/provider/bedrock/bedrock_provider.py
@@ -1,5 +1,5 @@
 import json
-from typing import Literal, Tuple
+from typing import Literal, Tuple, Union
 
 from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
 from metagpt.provider.bedrock.utils import (
@@ -32,6 +32,10 @@ class AnthropicProvider(BaseBedrockProvider):
         return self.messages_to_prompt(system_messages), user_messages
 
     def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
+        if self.reasoning:
+            generate_kwargs["temperature"] = 1  # Anthropic requires temperature=1 when thinking is enabled
+            generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_tokens}
+
         system_message, user_messages = self._split_system_user_messages(messages)
         body = json.dumps(
             {
@@ -43,17 +47,26 @@ class AnthropicProvider(BaseBedrockProvider):
         )
         return body
 
-    def _get_completion_from_dict(self, rsp_dict: dict) -> str:
+    def _get_completion_from_dict(self, rsp_dict: dict) -> Union[str, dict[str, str]]:
+        if self.reasoning:
+            return {"reasoning_content": rsp_dict["content"][0]["thinking"], "content": rsp_dict["content"][1]["text"]}
         return rsp_dict["content"][0]["text"]
 
-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         # https://docs.anthropic.com/claude/reference/messages-streaming
         rsp_dict = json.loads(event["chunk"]["bytes"])
         if rsp_dict["type"] == "content_block_delta":
-            completions = rsp_dict["delta"]["text"]
-            return completions
+            reasoning = False
+            if rsp_dict["delta"]["type"] == "text_delta":
+                completions = rsp_dict["delta"]["text"]
+            elif rsp_dict["delta"]["type"] == "thinking_delta":
+                completions = rsp_dict["delta"]["thinking"]
+                reasoning = True
+            else:
+                # signature_delta and other delta types carry no output text
+                completions = ""
+            return reasoning, completions
         else:
-            return ""
+            return False, ""
 
 
 class CohereProvider(BaseBedrockProvider):
@@ -87,10 +100,10 @@ class CohereProvider(BaseBedrockProvider):
         body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
         return body
 
-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict.get("text", "")
-        return completions
+        return False, completions
 
 
 class MetaProvider(BaseBedrockProvider):
@@ -133,10 +146,10 @@ class Ai21Provider(BaseBedrockProvider):
         )
         return body
 
-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
-        return completions
+        return False, completions
 
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         if self.model_type == "j2":
@@ -159,10 +172,10 @@ class AmazonProvider(BaseBedrockProvider):
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         return rsp_dict["results"][0]["outputText"]
 
-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict["outputText"]
-        return completions
+        return False, completions
 
 
 PROVIDERS = {
@@ -175,8 +188,14 @@ PROVIDERS = {
 }
 
 
-def get_provider(model_id: str):
-    provider, model_name = model_id.split(".")[0:2]  # meta、mistral……
+def get_provider(model_id: str, reasoning: bool = False, reasoning_tokens: int = 4000):
+    arr = model_id.split(".")
+    if len(arr) == 2:
+        provider, model_name = arr  # e.g. meta, mistral
+    elif len(arr) == 3:
+        # some model_ids carry a region prefix, e.g. us.anthropic.xxx
+        _, provider, model_name = arr
+    else:
+        raise ValueError(f"unexpected model_id format: {model_id}")
+
     if provider not in PROVIDERS:
         raise KeyError(f"{provider} is not supported!")
     if provider == "meta":
@@ -188,4 +207,4 @@ def get_provider(model_id: str):
     elif provider == "cohere":
         # distinguish between R/R+ and older models
         return PROVIDERS[provider](model_name)
-    return PROVIDERS[provider]()
+    return PROVIDERS[provider](reasoning=reasoning, reasoning_tokens=reasoning_tokens)
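Note: an illustrative check of the new (reasoning, text) tuple convention, using a fabricated Anthropic streaming event (the event shape follows the messages-streaming docs linked above; real events arrive on the Bedrock response stream):

import json

from metagpt.provider.bedrock.bedrock_provider import get_provider

provider = get_provider(
    "us.anthropic.claude-3-7-sonnet-20250219-v1:0", reasoning=True, reasoning_tokens=2000
)

# A fabricated thinking_delta chunk, serialized the way Bedrock delivers it.
event = {
    "chunk": {
        "bytes": json.dumps(
            {
                "type": "content_block_delta",
                "delta": {"type": "thinking_delta", "thinking": "Let me work through this..."},
            }
        )
    }
}

reasoning, text = provider.get_choice_text_from_stream(event)
assert reasoning is True and text == "Let me work through this..."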
diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py
index e67796362..371af38f9 100644
--- a/metagpt/provider/bedrock/utils.py
+++ b/metagpt/provider/bedrock/utils.py
@@ -48,6 +48,9 @@ SUPPORT_STREAM_MODELS = {
     "anthropic.claude-3-opus-20240229-v1:0": 4096,
     # Claude 3.5 Sonnet
     "anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
+    # Claude 3.7 Sonnet
+    "us.anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
+    "anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
     # Command Text
     # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
     "cohere.command-text-v14": 4096,
@@ -135,20 +138,6 @@ def messages_to_prompt_llama3(messages: list[dict]) -> str:
     return prompt
 
 
-def messages_to_prompt_claude2(messages: list[dict]) -> str:
-    GENERAL_TEMPLATE = "\n\n{role}: {content}"
-    prompt = ""
-    for message in messages:
-        role = message.get("role", "")
-        content = message.get("content", "")
-        prompt += GENERAL_TEMPLATE.format(role=role, content=content)
-
-    if role != "assistant":
-        prompt += "\n\nAssistant:"
-
-    return prompt
-
-
 def get_max_tokens(model_id: str) -> int:
     try:
         max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
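Note: the region-prefixed id is registered alongside the plain one, so the shared token-limit table resolves both spellings. A quick sanity check, assuming get_max_tokens returns the table value unchanged for known ids:

from metagpt.provider.bedrock.utils import get_max_tokens

# Both spellings of the Claude 3.7 Sonnet id resolve to the same limit.
assert get_max_tokens("anthropic.claude-3-7-sonnet-20250219-v1:0") == 131072
assert get_max_tokens("us.anthropic.claude-3-7-sonnet-20250219-v1:0") == 131072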
diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py
index 72eefa7e5..8bbd5fe67 100644
--- a/metagpt/provider/bedrock_api.py
+++ b/metagpt/provider/bedrock_api.py
@@ -23,7 +23,9 @@ class BedrockLLM(BaseLLM):
     def __init__(self, config: LLMConfig):
         self.config = config
         self.__client = self.__init_client("bedrock-runtime")
-        self.__provider = get_provider(self.config.model)
+        self.__provider = get_provider(
+            self.config.model, reasoning=self.config.reasoning, reasoning_tokens=self.config.reasoning_tokens
+        )
         self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
         if self.config.model in NOT_SUPPORT_STREAM_MODELS:
             logger.warning(f"model {self.config.model} doesn't support streaming output!")
@@ -102,7 +104,11 @@ class BedrockLLM(BaseLLM):
     # However, aioboto3 doesn't support invoke model
 
     def get_choice_text(self, rsp: dict) -> str:
-        return self.__provider.get_choice_text(rsp)
+        rsp = self.__provider.get_choice_text(rsp)
+        if isinstance(rsp, dict):
+            self.reasoning_content = rsp.get("reasoning_content")
+            rsp = rsp.get("content")
+        return rsp
 
     async def acompletion(self, messages: list[dict]) -> dict:
         request_body = self.__provider.get_request_body(messages, self._const_kwargs)
@@ -133,10 +139,16 @@ class BedrockLLM(BaseLLM):
     async def _get_stream_response_body(self, stream_response) -> List[str]:
         def collect_content() -> str:
             collected_content = []
+            collected_reasoning_content = []
             for event in stream_response["body"]:
-                chunk_text = self.__provider.get_choice_text_from_stream(event)
-                collected_content.append(chunk_text)
+                reasoning, chunk_text = self.__provider.get_choice_text_from_stream(event)
+                if reasoning:
+                    collected_reasoning_content.append(chunk_text)
+                else:
+                    collected_content.append(chunk_text)
                 log_llm_stream(chunk_text)
+            if collected_reasoning_content:
+                self.reasoning_content = "".join(collected_reasoning_content)
             return collected_content
 
         loop = asyncio.get_running_loop()
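Note: for the non-streaming path, the two-block content layout below is what the reasoning branch of _get_completion_from_dict expects; the response dict is fabricated for illustration, and BedrockLLM.get_choice_text then stores "reasoning_content" on the LLM and hands only "content" back to the caller:

from metagpt.provider.bedrock.bedrock_provider import get_provider

# Fabricated response in the shape Anthropic returns with thinking enabled:
# a "thinking" block first, then the regular "text" block.
rsp_body = {
    "content": [
        {"type": "thinking", "thinking": "Step 1: recall Rayleigh scattering..."},
        {"type": "text", "text": "Because shorter wavelengths scatter more."},
    ]
}

provider = get_provider("anthropic.claude-3-7-sonnet-20250219-v1:0", reasoning=True)
choice = provider.get_choice_text(rsp_body)
assert choice == {
    "reasoning_content": "Step 1: recall Rayleigh scattering...",
    "content": "Because shorter wavelengths scatter more.",
}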