mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
add openrouter reasoning
This commit is contained in:
parent
9245e8ab80
commit
ff477066c5
9 changed files with 115 additions and 17 deletions
|
|
@ -36,6 +36,7 @@ class LLMType(Enum):
|
|||
MISTRAL = "mistral"
|
||||
YI = "yi" # lingyiwanwu
|
||||
OPENROUTER = "openrouter"
|
||||
OPENROUTER_REASONING = "openrouter_reasoning"
|
||||
BEDROCK = "bedrock"
|
||||
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk
|
||||
|
||||
|
|
@ -102,7 +103,7 @@ class LLMConfig(YamlModel):
|
|||
|
||||
# reasoning / thinking switch
|
||||
reasoning: bool = False
|
||||
reasoning_tokens: int = 4000 # reasoning budget tokens to generate, usually smaller than max_tokens
|
||||
reasoning_max_token: int = 1024 # reasoning budget tokens to generate, usually smaller than max_token
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from metagpt.provider.dashscope_api import DashScopeLLM
|
|||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from metagpt.provider.ark_api import ArkLLM
|
||||
from metagpt.provider.openrouter_reasoning import OpenrouterReasoningLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
@ -34,4 +35,5 @@ __all__ = [
|
|||
"AnthropicLLM",
|
||||
"BedrockLLM",
|
||||
"ArkLLM",
|
||||
"OpenrouterReasoningLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class AnthropicLLM(BaseLLM):
|
|||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
if self.config.reasoning:
|
||||
kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.config.reasoning_tokens}
|
||||
kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.config.reasoning_max_token}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ class BaseBedrockProvider(ABC):
|
|||
# to handle different generation kwargs
|
||||
max_tokens_field_name = "max_tokens"
|
||||
|
||||
def __init__(self, reasoning: bool = False, reasoning_tokens: int = 4000):
|
||||
def __init__(self, reasoning: bool = False, reasoning_max_token: int = 1024):
|
||||
self.reasoning = reasoning
|
||||
self.reasoning_tokens = reasoning_tokens
|
||||
self.reasoning_max_token = reasoning_max_token
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ class MistralProvider(BaseBedrockProvider):
|
|||
|
||||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-37.html
|
||||
# https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html#anthropic_claude
|
||||
|
||||
def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
|
||||
system_messages = []
|
||||
|
|
@ -34,7 +36,7 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
|
||||
if self.reasoning:
|
||||
generate_kwargs["temperature"] = 1 # should be 1
|
||||
generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_tokens}
|
||||
generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_max_token}
|
||||
|
||||
system_message, user_messages = self._split_system_user_messages(messages)
|
||||
body = json.dumps(
|
||||
|
|
@ -189,7 +191,7 @@ PROVIDERS = {
|
|||
}
|
||||
|
||||
|
||||
def get_provider(model_id: str, reasoning: bool = False, reasoning_tokens: int = 4000):
|
||||
def get_provider(model_id: str, reasoning: bool = False, reasoning_max_token: int = 1024):
|
||||
arr = model_id.split(".")
|
||||
if len(arr) == 2:
|
||||
provider, model_name = arr # meta、mistral……
|
||||
|
|
@ -208,4 +210,4 @@ def get_provider(model_id: str, reasoning: bool = False, reasoning_tokens: int =
|
|||
elif provider == "cohere":
|
||||
# distinguish between R/R+ and older models
|
||||
return PROVIDERS[provider](model_name)
|
||||
return PROVIDERS[provider](reasoning=reasoning, reasoning_tokens=reasoning_tokens)
|
||||
return PROVIDERS[provider](reasoning=reasoning, reasoning_max_token=reasoning_max_token)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class BedrockLLM(BaseLLM):
|
|||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(
|
||||
self.config.model, reasoning=self.config.reasoning, reasoning_tokens=self.config.reasoning_tokens
|
||||
self.config.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token
|
||||
)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
|
|
|
|||
|
|
@ -150,6 +150,14 @@ class OpenAIResponse:
|
|||
h = self._headers.get("Openai-Processing-Ms")
|
||||
return None if h is None else round(float(h))
|
||||
|
||||
def decode_asjson(self) -> Optional[dict]:
|
||||
bstr = self.data.strip()
|
||||
if bstr.startswith(b"{") and bstr.endswith(b"}"):
|
||||
bstr = bstr.decode("utf-8")
|
||||
else:
|
||||
bstr = parse_stream_helper(bstr)
|
||||
return json.loads(bstr) if bstr else None
|
||||
|
||||
|
||||
def _build_api_url(url, query):
|
||||
scheme, netloc, path, base_query, fragment = urlsplit(url)
|
||||
|
|
@ -547,13 +555,6 @@ class APIRequestor:
|
|||
}
|
||||
try:
|
||||
result = await session.request(**request_kwargs)
|
||||
# log_info(
|
||||
# "LLM API response",
|
||||
# path=abs_url,
|
||||
# response_code=result.status,
|
||||
# processing_ms=result.headers.get("LLM-Processing-Ms"),
|
||||
# request_id=result.headers.get("X-Request-Id"),
|
||||
# )
|
||||
return result
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
|
||||
raise openai.APITimeoutError("Request timed out") from e
|
||||
|
|
|
|||
86
metagpt/provider/openrouter_reasoning.py
Normal file
86
metagpt/provider/openrouter_reasoning.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import json
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider([LLMType.OPENROUTER_REASONING])
|
||||
class OpenrouterReasoningLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.client = GeneralAPIRequestor(base_url=config.base_url)
|
||||
self.config = config
|
||||
self.model = self.config.model
|
||||
self.http_method = "post"
|
||||
self.base_url = "https://openrouter.ai/api/v1"
|
||||
self.url_suffix = "/chat/completions"
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.config.api_key}"}
|
||||
|
||||
def decode(self, response: OpenAIResponse) -> dict:
|
||||
return json.loads(response.data.decode("utf-8"))
|
||||
|
||||
def _const_kwargs(
|
||||
self, messages: list[dict], stream: bool = False, timeout=USE_CONFIG_TIMEOUT, **extra_kwargs
|
||||
) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"include_reasoning": True,
|
||||
"max_tokens": self.config.max_token,
|
||||
"temperature": self.config.temperature,
|
||||
"model": self.model,
|
||||
"stream": stream,
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
if "reasoning" in rsp["choices"][0]["message"]:
|
||||
self.reasoning_content = rsp["choices"][0]["message"]["reasoning"]
|
||||
return rsp["choices"][0]["message"]["content"]
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
|
||||
payload = self._const_kwargs(messages)
|
||||
resp, _, _ = await self.client.arequest(
|
||||
url=self.url_suffix, method=self.http_method, params=payload, headers=self.headers # empty
|
||||
)
|
||||
resp = resp.decode_asjson()
|
||||
self._update_costs(resp["usage"], model=self.model)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
self.headers["Content-Type"] = "text/event-stream" # update header to adapt the client
|
||||
payload = self._const_kwargs(messages, stream=True)
|
||||
resp, _, _ = await self.client.arequest(
|
||||
url=self.url_suffix, method=self.http_method, params=payload, headers=self.headers, stream=True # empty
|
||||
)
|
||||
collected_content = []
|
||||
collected_reasoning_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
chunk = chunk.decode_asjson()
|
||||
if not chunk:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if "reasoning" in delta and delta["reasoning"]:
|
||||
collected_reasoning_content.append(delta["reasoning"])
|
||||
elif delta["content"]:
|
||||
collected_content.append(delta["content"])
|
||||
log_llm_stream(delta["content"])
|
||||
|
||||
usage = chunk.get("usage")
|
||||
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage, model=self.model)
|
||||
full_content = "".join(collected_content)
|
||||
if collected_reasoning_content:
|
||||
self.reasoning_content = "".join(collected_reasoning_content)
|
||||
return full_content
|
||||
|
|
@ -74,6 +74,7 @@ TOKEN_COSTS = {
|
|||
"claude-3-5-sonnet-20240620": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
"claude-3-haiku-20240307": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"claude-3-7-sonnet-20250219": {"prompt": 0.003, "completion": 0.015},
|
||||
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"yi-large": {"prompt": 0.0028, "completion": 0.0028},
|
||||
|
|
@ -86,9 +87,14 @@ TOKEN_COSTS = {
|
|||
"openai/o1-mini": {"prompt": 0.003, "completion": 0.012},
|
||||
"anthropic/claude-3-opus": {"prompt": 0.015, "completion": 0.075},
|
||||
"anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic/claude-3.7-sonnet": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic/claude-3.7-sonnet:beta": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic/claude-3.7-sonnet:thinking": {"prompt": 0.003, "completion": 0.015},
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075}, # for openrouter, end
|
||||
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-chat": {"prompt": 0.00027, "completion": 0.0011},
|
||||
"deepseek-coder": {"prompt": 0.00027, "completion": 0.0011},
|
||||
"deepseek-reasoner": {"prompt": 0.00055, "completion": 0.0022},
|
||||
# For ark model https://www.volcengine.com/docs/82379/1099320
|
||||
"doubao-lite-4k-240515": {"prompt": 0.000043, "completion": 0.000086},
|
||||
"doubao-lite-32k-240515": {"prompt": 0.000043, "completion": 0.000086},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue