mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
eb3c6d14f9
41 changed files with 1093 additions and 363 deletions
|
|
@ -24,6 +24,8 @@ class LLMType(Enum):
|
|||
METAGPT = "metagpt"
|
||||
AZURE = "azure"
|
||||
OLLAMA = "ollama"
|
||||
QIANFAN = "qianfan" # Baidu BCE
|
||||
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
@ -36,13 +38,18 @@ class LLMConfig(YamlModel):
|
|||
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
api_key: str = "sk-"
|
||||
api_type: LLMType = LLMType.OPENAI
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
|
||||
# For Cloud Service Provider like Baidu/ Alibaba
|
||||
access_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
endpoint: Optional[str] = None # for self-deployed model on the cloud
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
app_id: Optional[str] = None
|
||||
api_secret: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
|
@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL
|
|||
from metagpt.document_store.faiss_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
from metagpt.utils.serialize import deserialize_message, serialize_message
|
||||
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ class MemoryStorage(FaissStore):
|
|||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
self.embedding = embedding or get_embedding()
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
|||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
|
|
@ -28,4 +30,6 @@ __all__ = [
|
|||
"OllamaLLM",
|
||||
"HumanProvider",
|
||||
"SparkLLM",
|
||||
"QianFanLLM",
|
||||
"DashScopeLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
|
|||
from typing import Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
|
|
@ -67,6 +68,28 @@ class BaseLLM(ABC):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
|
||||
"""update each request's token cost
|
||||
Args:
|
||||
model (str): model name or in some scenarios called endpoint
|
||||
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
|
||||
"""
|
||||
calc_usage = self.config.calc_usage and local_calc_usage
|
||||
model = model or self.model
|
||||
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
|
||||
if calc_usage and self.cost_manager:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
|
|||
248
metagpt/provider/dashscope_api.py
Normal file
248
metagpt/provider/dashscope_api.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from typing import Any, AsyncGenerator, Dict, List, Union
|
||||
|
||||
import dashscope
|
||||
from dashscope.aigc.generation import Generation
|
||||
from dashscope.api_entities.aiohttp_request import AioHttpRequest
|
||||
from dashscope.api_entities.api_request_data import ApiRequestData
|
||||
from dashscope.api_entities.api_request_factory import _get_protocol_params
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
GenerationOutput,
|
||||
GenerationResponse,
|
||||
Message,
|
||||
)
|
||||
from dashscope.client.base_api import BaseAioApi
|
||||
from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
|
||||
from dashscope.common.error import (
|
||||
InputDataRequired,
|
||||
InputRequired,
|
||||
ModelRequired,
|
||||
UnsupportedApiProtocol,
|
||||
)
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM, LLMConfig
|
||||
from metagpt.provider.llm_provider_registry import LLMType, register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
|
||||
|
||||
|
||||
def build_api_arequest(
|
||||
model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
|
||||
):
|
||||
(
|
||||
api_protocol,
|
||||
ws_stream_mode,
|
||||
is_binary_input,
|
||||
http_method,
|
||||
stream,
|
||||
async_request,
|
||||
query,
|
||||
headers,
|
||||
request_timeout,
|
||||
form,
|
||||
resources,
|
||||
) = _get_protocol_params(kwargs)
|
||||
task_id = kwargs.pop("task_id", None)
|
||||
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
|
||||
if not dashscope.base_http_api_url.endswith("/"):
|
||||
http_url = dashscope.base_http_api_url + "/"
|
||||
else:
|
||||
http_url = dashscope.base_http_api_url
|
||||
|
||||
if is_service:
|
||||
http_url = http_url + SERVICE_API_PATH + "/"
|
||||
|
||||
if task_group:
|
||||
http_url += "%s/" % task_group
|
||||
if task:
|
||||
http_url += "%s/" % task
|
||||
if function:
|
||||
http_url += function
|
||||
request = AioHttpRequest(
|
||||
url=http_url,
|
||||
api_key=api_key,
|
||||
http_method=http_method,
|
||||
stream=stream,
|
||||
async_request=async_request,
|
||||
query=query,
|
||||
timeout=request_timeout,
|
||||
task_id=task_id,
|
||||
)
|
||||
else:
|
||||
raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)
|
||||
|
||||
if headers is not None:
|
||||
request.add_headers(headers=headers)
|
||||
|
||||
if input is None and form is None:
|
||||
raise InputDataRequired("There is no input data and form data")
|
||||
|
||||
request_data = ApiRequestData(
|
||||
model,
|
||||
task_group=task_group,
|
||||
task=task,
|
||||
function=function,
|
||||
input=input,
|
||||
form=form,
|
||||
is_binary_input=is_binary_input,
|
||||
api_protocol=api_protocol,
|
||||
)
|
||||
request_data.add_resources(resources)
|
||||
request_data.add_parameters(**kwargs)
|
||||
request.data = request_data
|
||||
return request
|
||||
|
||||
|
||||
class AGeneration(Generation, BaseAioApi):
|
||||
@classmethod
|
||||
async def acall(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: Any = None,
|
||||
history: list = None,
|
||||
api_key: str = None,
|
||||
messages: List[Message] = None,
|
||||
plugins: Union[str, Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
|
||||
if (prompt is None or not prompt) and (messages is None or not messages):
|
||||
raise InputRequired("prompt or messages is required!")
|
||||
if model is None or not model:
|
||||
raise ModelRequired("Model is required!")
|
||||
task_group, function = "aigc", "generation" # fixed value
|
||||
if plugins is not None:
|
||||
headers = kwargs.pop("headers", {})
|
||||
if isinstance(plugins, str):
|
||||
headers["X-DashScope-Plugin"] = plugins
|
||||
else:
|
||||
headers["X-DashScope-Plugin"] = json.dumps(plugins)
|
||||
kwargs["headers"] = headers
|
||||
input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
|
||||
|
||||
api_key, model = BaseAioApi._validate_params(api_key, model)
|
||||
request = build_api_arequest(
|
||||
model=model,
|
||||
input=input,
|
||||
task_group=task_group,
|
||||
task=Generation.task,
|
||||
function=function,
|
||||
api_key=api_key,
|
||||
**kwargs,
|
||||
)
|
||||
response = await request.aio_call()
|
||||
is_stream = kwargs.get("stream", False)
|
||||
if is_stream:
|
||||
|
||||
async def aresp_iterator(response):
|
||||
async for resp in response:
|
||||
yield GenerationResponse.from_api_response(resp)
|
||||
|
||||
return aresp_iterator(response)
|
||||
else:
|
||||
return GenerationResponse.from_api_response(response)
|
||||
|
||||
|
||||
@register_provider(LLMType.DASHSCOPE)
|
||||
class DashScopeLLM(BaseLLM):
|
||||
def __init__(self, llm_config: LLMConfig):
|
||||
self.config = llm_config
|
||||
self.use_system_prompt = False # only some models support system_prompt
|
||||
self.__init_dashscope()
|
||||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_dashscope(self):
|
||||
self.model = self.config.model
|
||||
self.api_key = self.config.api_key
|
||||
self.token_costs = DASHSCOPE_TOKEN_COSTS
|
||||
self.aclient: AGeneration = AGeneration
|
||||
|
||||
# check support system_message models
|
||||
support_system_models = [
|
||||
"qwen-", # all support
|
||||
"llama2-", # all support
|
||||
"baichuan2-7b-chat-v1",
|
||||
"chatglm3-6b",
|
||||
]
|
||||
for support_model in support_system_models:
|
||||
if support_model in self.model:
|
||||
self.use_system_prompt = True
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"api_key": self.api_key,
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"result_format": "message",
|
||||
}
|
||||
if self.config.temperature > 0:
|
||||
# different model has default temperature. only set when it"s specified.
|
||||
kwargs["temperature"] = self.config.temperature
|
||||
if stream:
|
||||
kwargs["incremental_output"] = True
|
||||
return kwargs
|
||||
|
||||
def _check_response(self, resp: GenerationResponse):
|
||||
if resp.status_code != HTTPStatus.OK:
|
||||
raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")
|
||||
|
||||
def get_choice_text(self, output: GenerationOutput) -> str:
|
||||
return output.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
|
||||
def completion(self, messages: list[dict]) -> GenerationOutput:
|
||||
resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
|
||||
self._check_response(resp)
|
||||
|
||||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> GenerationOutput:
|
||||
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
|
||||
self._check_response(resp)
|
||||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
self._check_response(chunk)
|
||||
content = chunk.output.choices[0]["message"]["content"]
|
||||
usage = dict(chunk.usage) # each chunk has usage
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
@ -16,10 +16,10 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
|
|
@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM):
|
|||
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not context.cost_manager
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
|
|
@ -107,10 +96,11 @@ class FireworksLLM(OpenAILLM):
|
|||
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
|
||||
if choice_delta.content:
|
||||
collected_content.append(choice_delta.content)
|
||||
print(choice_delta.content, end="")
|
||||
log_llm_stream(choice_delta.content)
|
||||
if finish_reason:
|
||||
# fireworks api return usage when finish_reason is not None
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
log_llm_stream("\n")
|
||||
|
||||
full_content = "".join(collected_content)
|
||||
self._update_costs(usage)
|
||||
|
|
|
|||
|
|
@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM):
|
|||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
|
|||
|
|
@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"ollama updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the resp content from llm response"""
|
||||
assist_msg = resp.get("message", {})
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import Costs, TokenCostManager
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
|
|
@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
|
|||
logger.error(f"usage calculation failed!: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use OpenLLMCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
|||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -56,16 +56,13 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
kwargs = self._make_client_kwargs()
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
|
|
@ -102,7 +99,7 @@ class OpenAILLM(BaseLLM):
|
|||
"max_tokens": self._get_max_tokens(messages),
|
||||
"n": 1,
|
||||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": 0.3,
|
||||
"temperature": self.config.temperature,
|
||||
"model": self.model,
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
}
|
||||
|
|
@ -272,16 +269,6 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
return usage
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return self.config.max_token
|
||||
|
|
|
|||
152
metagpt/provider/qianfan_api.py
Normal file
152
metagpt/provider/qianfan_api.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
|
||||
import copy
|
||||
import os
|
||||
|
||||
import qianfan
|
||||
from qianfan import ChatCompletion
|
||||
from qianfan.resources.typing import JsonBody
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import (
|
||||
QIANFAN_ENDPOINT_TOKEN_COSTS,
|
||||
QIANFAN_MODEL_TOKEN_COSTS,
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMType.QIANFAN)
|
||||
class QianFanLLM(BaseLLM):
|
||||
"""
|
||||
Refs
|
||||
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
|
||||
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
|
||||
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.use_system_prompt = False # only some ERNIE-x related models support system_prompt
|
||||
self.__init_qianfan()
|
||||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_qianfan(self):
|
||||
if self.config.access_key and self.config.secret_key:
|
||||
# for system level auth, use access_key and secret_key, recommended by official
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
|
||||
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
|
||||
elif self.config.api_key and self.config.secret_key:
|
||||
# for application level auth, use api_key and secret_key
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
|
||||
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
|
||||
else:
|
||||
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
|
||||
|
||||
support_system_pairs = [
|
||||
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
|
||||
("ERNIE-Bot-8k", "ernie_bot_8k"),
|
||||
("ERNIE-Bot", "completions"),
|
||||
("ERNIE-Bot-turbo", "eb-instant"),
|
||||
("ERNIE-Speed", "ernie_speed"),
|
||||
("EB-turbo-AppBuilder", "ai_apaas"),
|
||||
]
|
||||
if self.config.model in [pair[0] for pair in support_system_pairs]:
|
||||
# only some ERNIE models support
|
||||
self.use_system_prompt = True
|
||||
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
|
||||
self.use_system_prompt = True
|
||||
|
||||
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
|
||||
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
|
||||
|
||||
self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
|
||||
self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)
|
||||
|
||||
# self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
|
||||
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
|
||||
self.aclient: ChatCompletion = qianfan.ChatCompletion()
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
if self.config.temperature > 0:
|
||||
# different model has default temperature. only set when it's specified.
|
||||
kwargs["temperature"] = self.config.temperature
|
||||
if self.config.endpoint:
|
||||
kwargs["endpoint"] = self.config.endpoint
|
||||
elif self.config.model:
|
||||
kwargs["model"] = self.config.model
|
||||
|
||||
if self.use_system_prompt:
|
||||
# if the model support system prompt, extract and pass it
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
model_or_endpoint = self.config.model or self.config.endpoint
|
||||
local_calc_usage = model_or_endpoint in self.token_costs
|
||||
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
|
||||
|
||||
def get_choice_text(self, resp: JsonBody) -> str:
|
||||
return resp.get("result", "")
|
||||
|
||||
def completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
content = chunk.body.get("result", "")
|
||||
usage = chunk.body.get("usage", {})
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
@ -53,16 +53,6 @@ class ZhiPuAILLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
i = action
|
||||
self._init_action(i)
|
||||
self.actions.append(i)
|
||||
self.states.append(f"{len(self.actions)}. {action}")
|
||||
self.states.append(f"{len(self.actions) - 1}. {action}")
|
||||
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
|
||||
"""Set strategy of the Role reacting to observed Message. Variation lies in how
|
||||
|
|
|
|||
|
|
@ -2,14 +2,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.context import Context
|
||||
from metagpt.const import CONFIG_ROOT
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||
|
|
@ -30,6 +27,8 @@ def generate_repo(
|
|||
recover_path=None,
|
||||
) -> ProjectRepo:
|
||||
"""Run the startup logic. Can be called from CLI or other Python scripts."""
|
||||
from metagpt.config2 import config
|
||||
from metagpt.context import Context
|
||||
from metagpt.roles import (
|
||||
Architect,
|
||||
Engineer,
|
||||
|
|
@ -122,7 +121,17 @@ def startup(
|
|||
)
|
||||
|
||||
|
||||
def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
|
||||
DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
|
||||
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
|
||||
llm:
|
||||
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
|
||||
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
base_url: "https://api.openai.com/v1" # or forward url / other llm url
|
||||
api_key: "YOUR_API_KEY"
|
||||
"""
|
||||
|
||||
|
||||
def copy_config_to():
|
||||
"""Initialize the configuration file for MetaGPT."""
|
||||
target_path = CONFIG_ROOT / "config2.yaml"
|
||||
|
||||
|
|
@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
|
|||
print(f"Existing configuration file backed up at {backup_path}")
|
||||
|
||||
# 复制文件
|
||||
shutil.copy(str(config_path), target_path)
|
||||
target_path.write_text(DEFAULT_CONFIG, encoding="utf-8")
|
||||
print(f"Configuration file initialized at {target_path}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class CostManager(BaseModel):
|
|||
total_budget: float = 0
|
||||
max_budget: float = 10.0
|
||||
total_cost: float = 0
|
||||
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
|
|
@ -41,12 +42,13 @@ class CostManager(BaseModel):
|
|||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
if model not in TOKEN_COSTS:
|
||||
if model not in self.token_costs:
|
||||
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
|
||||
return
|
||||
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
prompt_tokens * self.token_costs[model]["prompt"]
|
||||
+ completion_tokens * self.token_costs[model]["completion"]
|
||||
) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
|
|||
logger.info(f"repair_json_format: {'}]'}")
|
||||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
|
|
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
|
||||
# problem, `"""` or `'''` without `,`
|
||||
new_line = f",{line}"
|
||||
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
|
||||
# backslash problem like \" in the output
|
||||
char = rline[col_no - 1]
|
||||
nearest_char_idx = rline[col_no:].find(char)
|
||||
new_line = (
|
||||
rline[: col_no - 1]
|
||||
+ "\\"
|
||||
+ rline[col_no - 1 : col_no + nearest_char_idx]
|
||||
+ "\\"
|
||||
+ rline[col_no + nearest_char_idx :]
|
||||
)
|
||||
elif '",' not in line and "," not in line and '"' not in line:
|
||||
new_line = f'{line}",'
|
||||
elif not line.endswith(","):
|
||||
|
|
|
|||
|
|
@ -38,6 +38,88 @@ TOKEN_COSTS = {
|
|||
}
|
||||
|
||||
|
||||
"""
|
||||
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
|
||||
"""
|
||||
QIANFAN_MODEL_TOKEN_COSTS = {
|
||||
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
|
||||
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
|
||||
"ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
|
||||
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
QIANFAN_ENDPOINT_TOKEN_COSTS = {
|
||||
"completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
|
||||
"ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
|
||||
"completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
|
||||
"eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
|
||||
"ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
|
||||
"ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
|
||||
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
|
||||
"llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
|
||||
"llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
|
||||
"llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
|
||||
"chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
|
||||
"aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
|
||||
"mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
|
||||
"sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
|
||||
"codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
|
||||
"xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
|
||||
"qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
|
||||
"qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
|
||||
"qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
|
||||
"chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
|
||||
"yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
|
||||
}
|
||||
|
||||
"""
|
||||
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
|
||||
Different model has different detail page. Attention, some model are free for a limited time.
|
||||
"""
|
||||
DASHSCOPE_TOKEN_COSTS = {
|
||||
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"qwen-max": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
|
||||
"ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free
|
||||
"dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
|
||||
TOKEN_MAX = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue