mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-09 07:42:38 +02:00
Merge branch 'main' into code_interpreter
This commit is contained in:
commit
a570c81ccf
37 changed files with 632 additions and 356 deletions
|
|
@ -16,6 +16,7 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
|||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
|
|
@ -28,4 +29,5 @@ __all__ = [
|
|||
"OllamaLLM",
|
||||
"HumanProvider",
|
||||
"SparkLLM",
|
||||
"QianFanLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
|
|||
from typing import Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
|
|
@ -67,6 +68,28 @@ class BaseLLM(ABC):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
|
||||
"""update each request's token cost
|
||||
Args:
|
||||
model (str): model name or in some scenarios called endpoint
|
||||
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
|
||||
"""
|
||||
calc_usage = self.config.calc_usage and local_calc_usage
|
||||
model = model if model else self.model
|
||||
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
|
||||
if calc_usage and self.cost_manager:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.__class__.__name__} updats costs failed! exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
|
|
@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM):
|
|||
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not context.cost_manager
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
|
|
@ -113,7 +102,7 @@ class FireworksLLM(OpenAILLM):
|
|||
usage = CompletionUsage(**chunk.usage)
|
||||
|
||||
full_content = "".join(collected_content)
|
||||
self._update_costs(usage)
|
||||
self._update_costs(usage.model_dump())
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
|
|
|
|||
|
|
@ -60,7 +60,8 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
self, result: requests.Response, stream: bool
|
||||
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
|
||||
"""Returns the response(s) and a bool indicating whether it is a stream."""
|
||||
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
|
||||
content_type = result.headers.get("Content-Type", "")
|
||||
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
|
||||
return (
|
||||
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
|
||||
for line in parse_stream(result.iter_lines())
|
||||
|
|
|
|||
|
|
@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM):
|
|||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
|
|||
|
|
@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"ollama updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the resp content from llm response"""
|
||||
assist_msg = resp.get("message", {})
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import Costs, TokenCostManager
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
|
|
@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
|
|||
logger.error(f"usage calculation failed!: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use OpenLLMCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
|||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -56,16 +56,13 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
kwargs = self._make_client_kwargs()
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
|
|
@ -268,24 +265,16 @@ class OpenAILLM(BaseLLM):
|
|||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"usage calculation failed: {e}")
|
||||
logger.warning(f"usage calculation failed: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return self.config.max_token
|
||||
return get_max_completion_tokens(messages, self.model, self.config.max_tokens)
|
||||
# FIXME
|
||||
# https://community.openai.com/t/why-is-gpt-3-5-turbo-1106-max-tokens-limited-to-4096/494973/3
|
||||
return min(get_max_completion_tokens(messages, self.model, self.config.max_token), 4096)
|
||||
|
||||
@handle_exception
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
|
|
|
|||
152
metagpt/provider/qianfan_api.py
Normal file
152
metagpt/provider/qianfan_api.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
|
||||
import copy
|
||||
import os
|
||||
|
||||
import qianfan
|
||||
from qianfan import ChatCompletion
|
||||
from qianfan.resources.typing import JsonBody
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import (
|
||||
QianFan_EndPoint_TOKEN_COSTS,
|
||||
QianFan_MODEL_TOKEN_COSTS,
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMType.QIANFAN)
|
||||
class QianFanLLM(BaseLLM):
|
||||
"""
|
||||
Refs
|
||||
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
|
||||
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
|
||||
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.use_system_prompt = False # only some ERNIE-x related models support system_prompt
|
||||
self.__init_qianfan()
|
||||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_qianfan(self):
|
||||
if self.config.access_key and self.config.secret_key:
|
||||
# for system level auth, use access_key and secret_key, recommended by official
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
|
||||
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
|
||||
elif self.config.api_key and self.config.secret_key:
|
||||
# for application level auth, use api_key and secret_key
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
|
||||
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
|
||||
else:
|
||||
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
|
||||
|
||||
support_system_pairs = [
|
||||
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
|
||||
("ERNIE-Bot-8k", "ernie_bot_8k"),
|
||||
("ERNIE-Bot", "completions"),
|
||||
("ERNIE-Bot-turbo", "eb-instant"),
|
||||
("ERNIE-Speed", "ernie_speed"),
|
||||
("EB-turbo-AppBuilder", "ai_apaas"),
|
||||
]
|
||||
if self.config.model in [pair[0] for pair in support_system_pairs]:
|
||||
# only some ERNIE models support
|
||||
self.use_system_prompt = True
|
||||
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
|
||||
self.use_system_prompt = True
|
||||
|
||||
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
|
||||
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
|
||||
|
||||
self.token_costs = copy.deepcopy(QianFan_MODEL_TOKEN_COSTS)
|
||||
self.token_costs.update(QianFan_EndPoint_TOKEN_COSTS)
|
||||
|
||||
# self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
|
||||
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
|
||||
self.aclient: ChatCompletion = qianfan.ChatCompletion()
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
if self.config.temperature > 0:
|
||||
# different model has default temperature. only set when it's specified.
|
||||
kwargs["temperature"] = self.config.temperature
|
||||
if self.config.endpoint:
|
||||
kwargs["endpoint"] = self.config.endpoint
|
||||
elif self.config.model:
|
||||
kwargs["model"] = self.config.model
|
||||
|
||||
if self.use_system_prompt:
|
||||
# if the model support system prompt, extract and pass it
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
model_or_endpoint = self.config.model if self.config.model else self.config.endpoint
|
||||
local_calc_usage = True if model_or_endpoint in self.token_costs else False
|
||||
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
|
||||
|
||||
def get_choice_text(self, resp: JsonBody) -> str:
|
||||
return resp.get("result", "")
|
||||
|
||||
def completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
content = chunk.body.get("result", "")
|
||||
usage = chunk.body.get("usage", {})
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
@ -3,9 +3,8 @@
|
|||
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
import zhipuai
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
|
|
@ -14,6 +13,7 @@ from tenacity import (
|
|||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
from zhipuai.types.chat.chat_completion import Completion
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
|
|
@ -21,6 +21,7 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
||||
class ZhiPuEvent(Enum):
|
||||
|
|
@ -38,37 +39,22 @@ class ZhiPuAILLM(BaseLLM):
|
|||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.__init_zhipuai(config)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
|
||||
self.config = config
|
||||
self.__init_zhipuai()
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def __init_zhipuai(self, config: LLMConfig):
|
||||
assert config.api_key
|
||||
zhipuai.api_key = config.api_key
|
||||
# due to use openai sdk, set the api_key but it will't be used.
|
||||
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
|
||||
if config.proxy:
|
||||
# FIXME: openai v1.x sdk has no proxy support
|
||||
openai.proxy = config.proxy
|
||||
def __init_zhipuai(self):
|
||||
assert self.config.api_key
|
||||
self.api_key = self.config.api_key
|
||||
self.model = self.config.model # so far, it support glm-3-turbo、glm-4
|
||||
self.llm = ZhiPuModelAPI(api_key=self.api_key)
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
self._update_costs(usage)
|
||||
return resp.model_dump()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue