mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
Merge pull request #1726 from better629/reasoning
add LLM Reasoning models
This commit is contained in:
commit
55726e06a5
15 changed files with 233 additions and 68 deletions
|
|
@ -13,7 +13,9 @@ from metagpt.logs import logger
|
|||
|
||||
async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
    """Ask ``llm`` a single question in streaming mode, log the exchange and return the answer.

    Args:
        question: The user question; logged with a ``Q:`` prefix before the request.
        llm: The client whose ``aask`` coroutine is awaited and whose
            ``reasoning_content`` attribute is read once the answer has arrived.
        system_prompt: Passed to ``aask`` as the only element of ``system_msgs``.

    Returns:
        The value returned by ``llm.aask``.
    """
    logger.info(f"Q: {question}")
    # Exactly one (streaming) request: an additional non-streaming call would be paid for and then discarded.
    rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True)
    if llm.reasoning_content:
        logger.info(f"A reasoning: {llm.reasoning_content}")
    logger.info(f"A: {rsp}")
    return rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class LLMType(Enum):
|
|||
DEEPSEEK = "deepseek"
|
||||
SILICONFLOW = "siliconflow"
|
||||
OPENROUTER = "openrouter"
|
||||
OPENROUTER_REASONING = "openrouter_reasoning"
|
||||
BEDROCK = "bedrock"
|
||||
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk
|
||||
|
||||
|
|
@ -107,6 +108,10 @@ class LLMConfig(YamlModel):
|
|||
# For Messages Control
|
||||
use_system_prompt: bool = True
|
||||
|
||||
# reasoning / thinking switch
|
||||
reasoning: bool = False
|
||||
reasoning_max_token: int = 4000 # reasoning budget tokens to generate, usually smaller than max_token
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from metagpt.provider.dashscope_api import DashScopeLLM
|
|||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from metagpt.provider.ark_api import ArkLLM
|
||||
from metagpt.provider.openrouter_reasoning import OpenrouterReasoningLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
@ -34,4 +35,5 @@ __all__ = [
|
|||
"AnthropicLLM",
|
||||
"BedrockLLM",
|
||||
"ArkLLM",
|
||||
"OpenrouterReasoningLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ class AnthropicLLM(BaseLLM):
|
|||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
if self.config.reasoning:
|
||||
kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.config.reasoning_max_token}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
|
||||
|
|
@ -40,7 +42,12 @@ class AnthropicLLM(BaseLLM):
|
|||
super()._update_costs(usage, model)
|
||||
|
||||
def get_choice_text(self, resp: Message) -> str:
    """Return the answer text of a non-streaming response and record any thinking output.

    Content blocks are selected by their ``type`` instead of by position, so the result does not
    depend on the thinking block being first. This mirrors the streaming path, which joins all
    text deltas into the answer and all thinking deltas into ``reasoning_content``.

    Args:
        resp: Response object exposing a ``content`` list of blocks; ``thinking`` blocks carry
            a ``thinking`` attribute and ``text`` blocks carry a ``text`` attribute.

    Returns:
        The concatenated text of all ``text`` blocks.
    """
    thinking_parts = []
    text_parts = []
    for block in resp.content:
        block_type = getattr(block, "type", None)
        if block_type == "thinking":
            thinking_parts.append(block.thinking)
        elif block_type == "text":
            text_parts.append(block.text)
    # Only touch reasoning_content when the response actually contains thinking output.
    if thinking_parts:
        self.reasoning_content = "".join(thinking_parts)
    return "".join(text_parts)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
|
||||
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
|
||||
|
|
@ -53,6 +60,7 @@ class AnthropicLLM(BaseLLM):
|
|||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
collected_reasoning_content = []
|
||||
usage = Usage(input_tokens=0, output_tokens=0)
|
||||
async for event in stream:
|
||||
event_type = event.type
|
||||
|
|
@ -60,13 +68,19 @@ class AnthropicLLM(BaseLLM):
|
|||
usage.input_tokens = event.message.usage.input_tokens
|
||||
usage.output_tokens = event.message.usage.output_tokens
|
||||
elif event_type == "content_block_delta":
|
||||
content = event.delta.text
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
delta_type = event.delta.type
|
||||
if delta_type == "thinking_delta":
|
||||
collected_reasoning_content.append(event.delta.thinking)
|
||||
elif delta_type == "text_delta":
|
||||
content = event.delta.text
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
elif event_type == "message_delta":
|
||||
usage.output_tokens = event.usage.output_tokens # update final output_tokens
|
||||
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
if collected_reasoning_content:
|
||||
self.reasoning_content = "".join(collected_reasoning_content)
|
||||
return full_content
|
||||
|
|
|
|||
|
|
@ -45,6 +45,16 @@ class BaseLLM(ABC):
|
|||
model: Optional[str] = None # deprecated
|
||||
pricing_plan: Optional[str] = None
|
||||
|
||||
_reasoning_content: Optional[str] = None # content from reasoning mode
|
||||
|
||||
@property
|
||||
def reasoning_content(self):
|
||||
return self._reasoning_content
|
||||
|
||||
@reasoning_content.setter
|
||||
def reasoning_content(self, value: str):
|
||||
self._reasoning_content = value
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
|
@ -216,7 +226,10 @@ class BaseLLM(ABC):
|
|||
|
||||
def get_choice_text(self, rsp: dict) -> str:
    """Required to provide the first text of choice.

    When the first choice's message carries a ``reasoning_content`` key, its value is stored on
    ``self.reasoning_content`` before the answer text is returned.

    Args:
        rsp: Response dict with a ``choices`` list whose first item holds a ``message`` dict.

    Returns:
        The ``content`` value of the first choice's message.
    """
    message = rsp.get("choices")[0]["message"]
    if "reasoning_content" in message:
        self.reasoning_content = message["reasoning_content"]
    return message["content"]
|
||||
|
||||
def get_choice_delta_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of stream choice"""
|
||||
|
|
|
|||
|
|
@ -1,11 +1,16 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
|
||||
|
||||
class BaseBedrockProvider(ABC):
|
||||
# to handle different generation kwargs
|
||||
max_tokens_field_name = "max_tokens"
|
||||
|
||||
def __init__(self, reasoning: bool = False, reasoning_max_token: int = 4000):
|
||||
self.reasoning = reasoning
|
||||
self.reasoning_max_token = reasoning_max_token
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
|
@ -14,14 +19,14 @@ class BaseBedrockProvider(ABC):
|
|||
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response_body: dict) -> Union[str, dict[str, str]]:
    """Extract the completion from a non-streaming response body.

    Delegates to ``_get_completion_from_dict`` and returns its result unchanged, which is either
    the plain answer text or a dict, as the return annotation allows.
    """
    return self._get_completion_from_dict(response_body)
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> tuple[bool, str]:
    """Decode one stream event into an ``(is_reasoning, text)`` pair.

    Args:
        event: Stream event whose ``event["chunk"]["bytes"]`` holds a JSON document.

    Returns:
        ``(False, completion)``: this generic implementation never flags a chunk as reasoning;
        the completion is whatever ``_get_completion_from_dict`` returns for the decoded chunk.
    """
    rsp_dict = json.loads(event["chunk"]["bytes"])
    completions = self._get_completion_from_dict(rsp_dict)
    return False, completions
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Literal, Tuple
|
||||
from typing import Literal, Tuple, Union
|
||||
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
|
|
@ -20,6 +20,8 @@ class MistralProvider(BaseBedrockProvider):
|
|||
|
||||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-37.html
|
||||
# https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html#anthropic_claude
|
||||
|
||||
def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
|
||||
system_messages = []
|
||||
|
|
@ -32,6 +34,10 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
return self.messages_to_prompt(system_messages), user_messages
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
|
||||
if self.reasoning:
|
||||
generate_kwargs["temperature"] = 1 # should be 1
|
||||
generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_max_token}
|
||||
|
||||
system_message, user_messages = self._split_system_user_messages(messages)
|
||||
body = json.dumps(
|
||||
{
|
||||
|
|
@ -43,17 +49,27 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
)
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> dict[str, Tuple[str, str]]:
|
||||
if self.reasoning:
|
||||
return {"reasoning_content": rsp_dict["content"][0]["thinking"], "content": rsp_dict["content"][1]["text"]}
|
||||
return rsp_dict["content"][0]["text"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> tuple[bool, str]:
    """Decode one Claude stream event into an ``(is_reasoning, text)`` pair.

    See https://docs.anthropic.com/claude/reference/messages-streaming

    Args:
        event: Stream event whose ``event["chunk"]["bytes"]`` holds a JSON document.

    Returns:
        ``(True, thinking)`` for a ``thinking_delta``, ``(False, text)`` for a ``text_delta`` and
        ``(False, "")`` for everything else (signature deltas, other delta types, non-delta events).
    """
    rsp_dict = json.loads(event["chunk"]["bytes"])
    if rsp_dict["type"] != "content_block_delta":
        return False, ""

    delta = rsp_dict["delta"]
    delta_type = delta["type"]
    if delta_type == "text_delta":
        return False, delta["text"]
    if delta_type == "thinking_delta":
        return True, delta["thinking"]
    # signature_delta and any delta type not handled above contribute no visible text.
    return False, ""
|
||||
|
||||
|
||||
class CohereProvider(BaseBedrockProvider):
|
||||
|
|
@ -87,10 +103,10 @@ class CohereProvider(BaseBedrockProvider):
|
|||
body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> tuple[bool, str]:
    """Decode one Cohere stream event into an ``(is_reasoning, text)`` pair.

    Args:
        event: Stream event whose ``event["chunk"]["bytes"]`` holds a JSON document.

    Returns:
        ``(False, text)`` where ``text`` is the chunk's ``text`` field, or ``""`` when absent.
    """
    rsp_dict = json.loads(event["chunk"]["bytes"])
    return False, rsp_dict.get("text", "")
|
||||
|
||||
|
||||
class MetaProvider(BaseBedrockProvider):
|
||||
|
|
@ -133,10 +149,10 @@ class Ai21Provider(BaseBedrockProvider):
|
|||
)
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> tuple[bool, str]:
    """Decode one AI21 stream event into an ``(is_reasoning, text)`` pair.

    Args:
        event: Stream event whose ``event["chunk"]["bytes"]`` holds a JSON document.

    Returns:
        ``(False, text)`` where ``text`` is ``choices[0].delta.content`` of the chunk, or ``""``
        when any part of that path is missing.
    """
    rsp_dict = json.loads(event["chunk"]["bytes"])
    # `or [{}]` also covers an explicitly empty "choices" list, which would otherwise raise IndexError.
    choices = rsp_dict.get("choices") or [{}]
    completions = choices[0].get("delta", {}).get("content", "")
    return False, completions
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
if self.model_type == "j2":
|
||||
|
|
@ -159,10 +175,10 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["results"][0]["outputText"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> tuple[bool, str]:
    """Decode one Amazon stream event into an ``(is_reasoning, text)`` pair.

    Args:
        event: Stream event whose ``event["chunk"]["bytes"]`` holds a JSON document.

    Returns:
        ``(False, text)`` where ``text`` is the chunk's ``outputText`` field.

    Raises:
        KeyError: If the decoded chunk has no ``outputText`` key.
    """
    rsp_dict = json.loads(event["chunk"]["bytes"])
    return False, rsp_dict["outputText"]
|
||||
|
||||
|
||||
PROVIDERS = {
|
||||
|
|
@ -175,8 +191,14 @@ PROVIDERS = {
|
|||
}
|
||||
|
||||
|
||||
def get_provider(model_id: str):
|
||||
provider, model_name = model_id.split(".")[0:2] # meta、mistral……
|
||||
def get_provider(model_id: str, reasoning: bool = False, reasoning_max_token: int = 4000):
|
||||
arr = model_id.split(".")
|
||||
if len(arr) == 2:
|
||||
provider, model_name = arr # meta、mistral……
|
||||
elif len(arr) == 3:
|
||||
# some model_ids may contain country like us.xx.xxx
|
||||
_, provider, model_name = arr
|
||||
|
||||
if provider not in PROVIDERS:
|
||||
raise KeyError(f"{provider} is not supported!")
|
||||
if provider == "meta":
|
||||
|
|
@ -188,4 +210,4 @@ def get_provider(model_id: str):
|
|||
elif provider == "cohere":
|
||||
# distinguish between R/R+ and older models
|
||||
return PROVIDERS[provider](model_name)
|
||||
return PROVIDERS[provider]()
|
||||
return PROVIDERS[provider](reasoning=reasoning, reasoning_max_token=reasoning_max_token)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ SUPPORT_STREAM_MODELS = {
|
|||
"anthropic.claude-3-opus-20240229-v1:0": 4096,
|
||||
# Claude 3.5 Sonnet
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
|
||||
# Claude 3.7 Sonnet
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
|
||||
# Command Text
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
"cohere.command-text-v14": 4096,
|
||||
|
|
@ -135,20 +138,6 @@ def messages_to_prompt_llama3(messages: list[dict]) -> str:
|
|||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_claude2(messages: list[dict]) -> str:
    """Render chat messages as a single ``\\n\\n{role}: {content}`` prompt string.

    Args:
        messages: Message dicts; missing ``role`` or ``content`` keys are rendered as ``""``.

    Returns:
        The rendered turns, followed by ``"\\n\\nAssistant:"`` unless the last message's role
        is ``assistant``. An empty list yields just ``"\\n\\nAssistant:"``.
    """
    GENERAL_TEMPLATE = "\n\n{role}: {content}"
    role = ""  # defined up front so an empty message list does not hit an unbound name below
    parts = []
    for message in messages:
        role = message.get("role", "")
        parts.append(GENERAL_TEMPLATE.format(role=role, content=message.get("content", "")))

    # Cue the model to answer unless the conversation already ends on an assistant turn.
    if role != "assistant":
        parts.append("\n\nAssistant:")

    return "".join(parts)
|
||||
|
||||
|
||||
def get_max_tokens(model_id: str) -> int:
|
||||
try:
|
||||
max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ class BedrockLLM(BaseLLM):
|
|||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.__provider = get_provider(
|
||||
self.config.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token
|
||||
)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
|
|
@ -102,7 +104,11 @@ class BedrockLLM(BaseLLM):
|
|||
# However,aioboto3 doesn't support invoke model
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
    """Return the answer text of a response, recording reasoning output when the provider supplies it.

    The provider's ``get_choice_text`` result is either the answer itself or a dict; for a dict,
    its ``reasoning_content`` entry is stored on ``self.reasoning_content`` and its ``content``
    entry is returned.

    Args:
        rsp: Response body handed to the provider unchanged.

    Returns:
        The answer extracted by the provider.
    """
    choice = self.__provider.get_choice_text(rsp)
    if isinstance(choice, dict):
        self.reasoning_content = choice.get("reasoning_content")
        return choice.get("content")
    return choice
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
|
|
@ -133,10 +139,16 @@ class BedrockLLM(BaseLLM):
|
|||
async def _get_stream_response_body(self, stream_response) -> List[str]:
|
||||
def collect_content() -> str:
|
||||
collected_content = []
|
||||
collected_reasoning_content = []
|
||||
for event in stream_response["body"]:
|
||||
chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
reasoning, chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
if reasoning:
|
||||
collected_reasoning_content.append(chunk_text)
|
||||
else:
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
if collected_reasoning_content:
|
||||
self.reasoning_content = "".join(collected_reasoning_content)
|
||||
return collected_content
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
|
|||
|
|
@ -150,6 +150,14 @@ class OpenAIResponse:
|
|||
h = self._headers.get("Openai-Processing-Ms")
|
||||
return None if h is None else round(float(h))
|
||||
|
||||
def decode_asjson(self) -> Optional[dict]:
    """Parse the response payload in ``self.data`` as JSON.

    A payload that, once stripped, starts with ``{`` and ends with ``}`` is decoded as UTF-8 and
    parsed directly; anything else is first passed through ``parse_stream_helper``.

    Returns:
        The parsed JSON value, or ``None`` when there is nothing to parse.
    """
    raw = self.data.strip()
    is_json_object = raw.startswith(b"{") and raw.endswith(b"}")
    text = raw.decode("utf-8") if is_json_object else parse_stream_helper(raw)
    if not text:
        return None
    return json.loads(text)
|
||||
|
||||
|
||||
def _build_api_url(url, query):
|
||||
scheme, netloc, path, base_query, fragment = urlsplit(url)
|
||||
|
|
@ -547,13 +555,6 @@ class APIRequestor:
|
|||
}
|
||||
try:
|
||||
result = await session.request(**request_kwargs)
|
||||
# log_info(
|
||||
# "LLM API response",
|
||||
# path=abs_url,
|
||||
# response_code=result.status,
|
||||
# processing_ms=result.headers.get("LLM-Processing-Ms"),
|
||||
# request_id=result.headers.get("X-Request-Id"),
|
||||
# )
|
||||
return result
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
|
||||
raise openai.APITimeoutError("Request timed out") from e
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class OpenAILLM(BaseLLM):
|
|||
def _get_proxy_params(self) -> dict:
|
||||
params = {}
|
||||
if self.config.proxy:
|
||||
params = {"proxies": self.config.proxy}
|
||||
params = {"proxy": self.config.proxy}
|
||||
if self.config.base_url:
|
||||
params["base_url"] = self.config.base_url
|
||||
|
||||
|
|
@ -94,12 +94,19 @@ class OpenAILLM(BaseLLM):
|
|||
)
|
||||
usage = None
|
||||
collected_messages = []
|
||||
collected_reasoning_messages = []
|
||||
has_finished = False
|
||||
async for chunk in response:
|
||||
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
|
||||
finish_reason = (
|
||||
chunk.choices[0].finish_reason if chunk.choices and hasattr(chunk.choices[0], "finish_reason") else None
|
||||
)
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
choice0 = chunk.choices[0]
|
||||
choice_delta = choice0.delta
|
||||
if hasattr(choice_delta, "reasoning_content") and choice_delta.reasoning_content:
|
||||
collected_reasoning_messages.append(choice_delta.reasoning_content) # for deepseek
|
||||
continue
|
||||
chunk_message = choice_delta.content or "" # extract the message
|
||||
finish_reason = choice0.finish_reason if hasattr(choice0, "finish_reason") else None
|
||||
log_llm_stream(chunk_message)
|
||||
collected_messages.append(chunk_message)
|
||||
chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
|
||||
|
|
@ -110,17 +117,16 @@ class OpenAILLM(BaseLLM):
|
|||
if finish_reason:
|
||||
if chunk_has_usage:
|
||||
# Some services have usage as an attribute of the chunk, such as Fireworks
|
||||
if isinstance(chunk.usage, CompletionUsage):
|
||||
usage = chunk.usage
|
||||
else:
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
elif hasattr(chunk.choices[0], "usage"):
|
||||
usage = CompletionUsage(**chunk.usage) if isinstance(chunk.usage, dict) else chunk.usage
|
||||
elif hasattr(choice0, "usage"):
|
||||
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
|
||||
usage = CompletionUsage(**chunk.choices[0].usage)
|
||||
usage = CompletionUsage(**choice0.usage)
|
||||
has_finished = True
|
||||
|
||||
log_llm_stream("\n")
|
||||
full_reply_content = "".join(collected_messages)
|
||||
if collected_reasoning_messages:
|
||||
self.reasoning_content = "".join(collected_reasoning_messages)
|
||||
if not usage:
|
||||
# Some services do not provide the usage attribute, such as OpenAI or OpenLLM
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
|
|
|
|||
86
metagpt/provider/openrouter_reasoning.py
Normal file
86
metagpt/provider/openrouter_reasoning.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import json
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider([LLMType.OPENROUTER_REASONING])
class OpenrouterReasoningLLM(BaseLLM):
    """Chat-completions client for OpenRouter that also captures the model's reasoning output.

    Requests go through ``GeneralAPIRequestor`` with ``include_reasoning`` switched on. Reasoning
    text found in a response is stored on ``self.reasoning_content``; the answer text is returned.
    """

    def __init__(self, config: LLMConfig):
        """Build the HTTP requestor and the static request settings from ``config``."""
        self.client = GeneralAPIRequestor(base_url=config.base_url)
        self.config = config
        self.model = self.config.model
        self.http_method = "post"
        # NOTE(review): the requestor above is built from config.base_url; this attribute is not
        # read anywhere in this class.
        self.base_url = "https://openrouter.ai/api/v1"
        self.url_suffix = "/chat/completions"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.config.api_key}"}

    def decode(self, response: OpenAIResponse) -> dict:
        """Decode the raw bytes of ``response`` as UTF-8 JSON."""
        return json.loads(response.data.decode("utf-8"))

    def _const_kwargs(
        self, messages: list[dict], stream: bool = False, timeout=USE_CONFIG_TIMEOUT, **extra_kwargs
    ) -> dict:
        """Assemble the request payload; ``timeout`` and ``extra_kwargs`` are accepted but not used."""
        kwargs = {
            "messages": messages,
            "include_reasoning": True,
            "max_tokens": self.config.max_token,
            "temperature": self.config.temperature,
            "model": self.model,
            "stream": stream,
        }
        return kwargs

    def get_choice_text(self, rsp: dict) -> str:
        """Return the first choice's answer text, storing its ``reasoning`` field when present."""
        message = rsp["choices"][0]["message"]
        if "reasoning" in message:
            self.reasoning_content = message["reasoning"]
        return message["content"]

    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
        """Send a non-streaming request, update the cost bookkeeping and return the decoded response."""
        payload = self._const_kwargs(messages)
        resp, _, _ = await self.client.arequest(
            url=self.url_suffix, method=self.http_method, params=payload, headers=self.headers
        )
        resp = resp.decode_asjson()
        self._update_costs(resp["usage"], model=self.model)
        return resp

    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
        """Public non-streaming entry point; resolves the timeout and delegates to ``_achat_completion``."""
        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
        """Send a streaming request and return the concatenated answer text.

        Answer deltas are echoed through ``log_llm_stream`` as they arrive. Reasoning deltas are
        collected separately and stored on ``self.reasoning_content`` when any were received.
        """
        # Use a per-request copy: writing the event-stream content type into self.headers would
        # leak it into every later non-streaming request made by this instance.
        headers = {**self.headers, "Content-Type": "text/event-stream"}
        payload = self._const_kwargs(messages, stream=True)
        resp, _, _ = await self.client.arequest(
            url=self.url_suffix, method=self.http_method, params=payload, headers=headers, stream=True
        )
        collected_content = []
        collected_reasoning_content = []
        usage = {}
        async for chunk in resp:
            chunk = chunk.decode_asjson()
            if not chunk:
                continue
            # Keep the latest usage report; a chunk without one must not erase an earlier value.
            usage = chunk.get("usage") or usage
            choices = chunk.get("choices")
            if not choices:
                # e.g. a chunk that carries only usage data and no choice
                continue
            delta = choices[0].get("delta") or {}
            if delta.get("reasoning"):
                collected_reasoning_content.append(delta["reasoning"])
            elif delta.get("content"):
                collected_content.append(delta["content"])
                log_llm_stream(delta["content"])

        log_llm_stream("\n")
        self._update_costs(usage, model=self.model)
        full_content = "".join(collected_content)
        if collected_reasoning_content:
            self.reasoning_content = "".join(collected_reasoning_content)
        return full_content
|
||||
|
|
@ -144,6 +144,6 @@ class FireworksCostManager(CostManager):
|
|||
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.4f}"
|
||||
f"Total running cost: ${self.total_cost:.4f}, "
|
||||
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ TOKEN_COSTS = {
|
|||
"claude-3-5-sonnet-20240620": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
"claude-3-haiku-20240307": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"claude-3-7-sonnet-20250219": {"prompt": 0.003, "completion": 0.015},
|
||||
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"openai/gpt-4": {"prompt": 0.03, "completion": 0.06}, # start, for openrouter
|
||||
|
|
@ -92,9 +93,16 @@ TOKEN_COSTS = {
|
|||
"openai/o1-preview": {"prompt": 0.015, "completion": 0.06},
|
||||
"openai/o1-mini": {"prompt": 0.003, "completion": 0.012},
|
||||
"anthropic/claude-3-opus": {"prompt": 0.015, "completion": 0.075},
|
||||
"anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic/claude-3.7-sonnet": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic/claude-3.7-sonnet:beta": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic/claude-3.7-sonnet:thinking": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075}, # for openrouter, end
|
||||
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-chat": {"prompt": 0.00027, "completion": 0.0011},
|
||||
"deepseek-coder": {"prompt": 0.00027, "completion": 0.0011},
|
||||
"deepseek-reasoner": {"prompt": 0.00055, "completion": 0.0022},
|
||||
# For ark model https://www.volcengine.com/docs/82379/1099320
|
||||
"doubao-lite-4k-240515": {"prompt": 0.000043, "completion": 0.000086},
|
||||
"doubao-lite-32k-240515": {"prompt": 0.000043, "completion": 0.000086},
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ lancedb==0.4.0
|
|||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy~=1.26.4
|
||||
openai~=1.39.0
|
||||
openai~=1.64.0
|
||||
openpyxl~=3.1.5
|
||||
beautifulsoup4==4.12.3
|
||||
pandas==2.1.1
|
||||
|
|
@ -31,7 +31,7 @@ tqdm==4.66.2
|
|||
#unstructured[local-inference]
|
||||
# selenium>4
|
||||
# webdriver_manager<3.9
|
||||
anthropic==0.18.1
|
||||
anthropic==0.47.2
|
||||
typing-inspect==0.8.0
|
||||
libcst==1.0.1
|
||||
qdrant-client==1.7.0
|
||||
|
|
@ -59,7 +59,7 @@ nbformat==5.9.2
|
|||
ipython==8.17.2
|
||||
ipykernel==6.27.1
|
||||
scikit_learn==1.3.2
|
||||
typing-extensions==4.9.0
|
||||
typing-extensions==4.11.0
|
||||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
|
|
@ -88,4 +88,4 @@ spark_ai_python~=0.3.30
|
|||
agentops
|
||||
tree_sitter~=0.23.2
|
||||
tree_sitter_python~=0.23.2
|
||||
httpx==0.27.2
|
||||
httpx==0.28.1
|
||||
Loading…
Add table
Add a link
Reference in a new issue