Merge remote-tracking branch 'origin/main'

This commit is contained in:
mannaandpoem 2024-02-29 16:50:50 +08:00
commit eb3c6d14f9
41 changed files with 1093 additions and 363 deletions

View file

@ -24,6 +24,8 @@ class LLMType(Enum):
METAGPT = "metagpt"
AZURE = "azure"
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
def __missing__(self, key):
return self.OPENAI
@ -36,13 +38,18 @@ class LLMConfig(YamlModel):
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str
api_key: str = "sk-"
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For Cloud Service Provider like Baidu/ Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None
api_secret: Optional[str] = None

View file

@ -7,7 +7,6 @@
from pathlib import Path
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message
@ -30,7 +30,7 @@ class MemoryStorage(FaissStore):
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.embedding = embedding or get_embedding()
self.store: FAISS = None # Faiss engine
@property

View file

@ -16,6 +16,8 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
__all__ = [
"FireworksLLM",
@ -28,4 +30,6 @@ __all__ = [
"OllamaLLM",
"HumanProvider",
"SparkLLM",
"QianFanLLM",
"DashScopeLLM",
]

View file

@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs
class BaseLLM(ABC):
@ -67,6 +68,28 @@ class BaseLLM(ABC):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
    """Update the cost manager with one request's token usage.

    Args:
        usage: token usage of the request; either a plain dict or a pydantic
            model exposing ``prompt_tokens`` / ``completion_tokens``.
        model (str): model name or in some scenarios called endpoint
        local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
    """
    # both the global config switch and the per-call switch must allow it
    calc_usage = self.config.calc_usage and local_calc_usage
    model = model or self.model
    # normalize pydantic usage objects (e.g. openai CompletionUsage) to a dict
    usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
    if calc_usage and self.cost_manager:
        try:
            prompt_tokens = int(usage.get("prompt_tokens", 0))
            completion_tokens = int(usage.get("completion_tokens", 0))
            self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
        except Exception as e:
            # cost tracking must never break the main request path — log and continue
            logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
def get_costs(self) -> Costs:
    """Return the accumulated costs, or an all-zero Costs when no cost manager is configured."""
    return self.cost_manager.get_costs() if self.cost_manager else Costs(0, 0, 0, 0)
async def aask(
self,
msg: str,

View file

@ -0,0 +1,248 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import json
from http import HTTPStatus
from typing import Any, AsyncGenerator, Dict, List, Union
import dashscope
from dashscope.aigc.generation import Generation
from dashscope.api_entities.aiohttp_request import AioHttpRequest
from dashscope.api_entities.api_request_data import ApiRequestData
from dashscope.api_entities.api_request_factory import _get_protocol_params
from dashscope.api_entities.dashscope_response import (
GenerationOutput,
GenerationResponse,
Message,
)
from dashscope.client.base_api import BaseAioApi
from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
from dashscope.common.error import (
InputDataRequired,
InputRequired,
ModelRequired,
UnsupportedApiProtocol,
)
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM, LLMConfig
from metagpt.provider.llm_provider_registry import LLMType, register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
def build_api_arequest(
    model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
):
    """Build an async HTTP request for the DashScope API.

    Mirrors dashscope's internal request factory, but produces an
    ``AioHttpRequest`` so the call can be awaited.

    NOTE(review): the ``input`` parameter shadows the ``input`` builtin; the
    name is kept to match the upstream dashscope factory signature.

    Args:
        model: model name to invoke.
        input: request payload (prompt/messages already packed by the caller).
        task_group: first URL path segment (e.g. "aigc").
        task: second URL path segment (e.g. the Generation task).
        function: final URL path segment (e.g. "generation").
        api_key: DashScope API key.
        is_service: when True, prefix the path with SERVICE_API_PATH.
        **kwargs: protocol parameters consumed by ``_get_protocol_params`` plus
            extra request parameters forwarded to ``ApiRequestData``.

    Returns:
        AioHttpRequest ready for ``aio_call()``.

    Raises:
        UnsupportedApiProtocol: when the configured protocol is not http/https.
        InputDataRequired: when neither input nor form data is provided.
    """
    # split kwargs into transport/protocol settings vs. model parameters
    (
        api_protocol,
        ws_stream_mode,
        is_binary_input,
        http_method,
        stream,
        async_request,
        query,
        headers,
        request_timeout,
        form,
        resources,
    ) = _get_protocol_params(kwargs)
    task_id = kwargs.pop("task_id", None)
    if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
        # assemble {base}/{service}/{task_group}/{task}/{function} with exactly
        # one slash between each segment
        if not dashscope.base_http_api_url.endswith("/"):
            http_url = dashscope.base_http_api_url + "/"
        else:
            http_url = dashscope.base_http_api_url
        if is_service:
            http_url = http_url + SERVICE_API_PATH + "/"
        if task_group:
            http_url += "%s/" % task_group
        if task:
            http_url += "%s/" % task
        if function:
            http_url += function
        request = AioHttpRequest(
            url=http_url,
            api_key=api_key,
            http_method=http_method,
            stream=stream,
            async_request=async_request,
            query=query,
            timeout=request_timeout,
            task_id=task_id,
        )
    else:
        # websocket protocols are not supported by this async path
        raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)
    if headers is not None:
        request.add_headers(headers=headers)
    if input is None and form is None:
        raise InputDataRequired("There is no input data and form data")
    request_data = ApiRequestData(
        model,
        task_group=task_group,
        task=task,
        function=function,
        input=input,
        form=form,
        is_binary_input=is_binary_input,
        api_protocol=api_protocol,
    )
    request_data.add_resources(resources)
    # remaining kwargs (after protocol params were stripped) are model parameters
    request_data.add_parameters(**kwargs)
    request.data = request_data
    return request
class AGeneration(Generation, BaseAioApi):
    """Async counterpart of dashscope's ``Generation``.

    Reuses Generation's input-building logic but dispatches the request through
    ``AioHttpRequest.aio_call()`` so it can be awaited.
    """

    @classmethod
    async def acall(
        cls,
        model: str,
        prompt: Any = None,
        history: list = None,
        api_key: str = None,
        messages: List[Message] = None,
        plugins: Union[str, Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
        """Async generation call.

        Returns a single GenerationResponse, or — when ``stream=True`` is
        passed in kwargs — an async generator yielding incremental responses.

        Raises:
            InputRequired: when neither prompt nor messages is provided.
            ModelRequired: when model is empty.
        """
        if (prompt is None or not prompt) and (messages is None or not messages):
            raise InputRequired("prompt or messages is required!")
        if model is None or not model:
            raise ModelRequired("Model is required!")
        task_group, function = "aigc", "generation"  # fixed value
        if plugins is not None:
            # plugins ride in a dedicated header, serialized if given as a dict
            headers = kwargs.pop("headers", {})
            if isinstance(plugins, str):
                headers["X-DashScope-Plugin"] = plugins
            else:
                headers["X-DashScope-Plugin"] = json.dumps(plugins)
            kwargs["headers"] = headers
        input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
        api_key, model = BaseAioApi._validate_params(api_key, model)
        request = build_api_arequest(
            model=model,
            input=input,
            task_group=task_group,
            task=Generation.task,
            function=function,
            api_key=api_key,
            **kwargs,
        )
        response = await request.aio_call()
        is_stream = kwargs.get("stream", False)
        if is_stream:
            # wrap the raw async response stream so each chunk is parsed into
            # a GenerationResponse before being yielded to the caller
            async def aresp_iterator(response):
                async for resp in response:
                    yield GenerationResponse.from_api_response(resp)

            return aresp_iterator(response)
        else:
            return GenerationResponse.from_api_response(response)
@register_provider(LLMType.DASHSCOPE)
class DashScopeLLM(BaseLLM):
    """LLM provider for Aliyun LingJi DashScope generation models."""

    def __init__(self, llm_config: LLMConfig):
        self.config = llm_config
        self.use_system_prompt = False  # only some models support system_prompt
        self.__init_dashscope()
        # provider-local cost manager carrying DashScope-specific token prices
        self.cost_manager = CostManager(token_costs=self.token_costs)

    def __init_dashscope(self):
        """Cache config fields and detect system-prompt support from the model name."""
        self.model = self.config.model
        self.api_key = self.config.api_key
        self.token_costs = DASHSCOPE_TOKEN_COSTS
        # AGeneration is used as a class-level client; calls go through its classmethods
        self.aclient: AGeneration = AGeneration

        # check support system_message models
        support_system_models = [
            "qwen-",  # all support
            "llama2-",  # all support
            "baichuan2-7b-chat-v1",
            "chatglm3-6b",
        ]
        # substring match: model families are matched by prefix entries above
        for support_model in support_system_models:
            if support_model in self.model:
                self.use_system_prompt = True

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the keyword arguments shared by all generation calls."""
        kwargs = {
            "api_key": self.api_key,
            "model": self.model,
            "messages": messages,
            "stream": stream,
            "result_format": "message",
        }
        if self.config.temperature > 0:
            # different model has default temperature. only set when it's specified.
            kwargs["temperature"] = self.config.temperature
        if stream:
            # incremental_output makes each chunk carry only the delta text
            kwargs["incremental_output"] = True
        return kwargs

    def _check_response(self, resp: GenerationResponse):
        """Raise RuntimeError on any non-200 API response."""
        if resp.status_code != HTTPStatus.OK:
            raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")

    def get_choice_text(self, output: GenerationOutput) -> str:
        # defensive .get chain: returns "" when choices/message/content is absent
        return output.get("choices", [{}])[0].get("message", {}).get("content", "")

    def completion(self, messages: list[dict]) -> GenerationOutput:
        """Synchronous, non-streaming completion."""
        resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
        self._check_response(resp)
        self._update_costs(dict(resp.usage))
        return resp.output

    async def _achat_completion(self, messages: list[dict]) -> GenerationOutput:
        """Async, non-streaming completion."""
        resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
        self._check_response(resp)
        self._update_costs(dict(resp.usage))
        return resp.output

    async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
        # NOTE(review): timeout is accepted for interface parity but not forwarded
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Async streaming completion; logs chunks as they arrive and returns the joined text."""
        resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
        collected_content = []
        usage = {}
        async for chunk in resp:
            self._check_response(chunk)
            content = chunk.output.choices[0]["message"]["content"]
            usage = dict(chunk.usage)  # each chunk has usage; only the last one is billed
            log_llm_stream(content)
            collected_content.append(content)
        log_llm_stream("\n")
        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
        """Async completion returning plain text; retries up to 3 times on ConnectionError."""
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)

View file

@ -16,10 +16,10 @@ from tenacity import (
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM):
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
@ -107,10 +96,11 @@ class FireworksLLM(OpenAILLM):
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
if choice_delta.content:
collected_content.append(choice_delta.content)
print(choice_delta.content, end="")
log_llm_stream(choice_delta.content)
if finish_reason:
# fireworks api return usage when finish_reason is not None
usage = CompletionUsage(**chunk.usage)
log_llm_stream("\n")
full_content = "".join(collected_content)
self._update_costs(usage)

View file

@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})

View file

@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.cost_manager import TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -56,16 +56,13 @@ class OpenAILLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -102,7 +99,7 @@ class OpenAILLM(BaseLLM):
"max_tokens": self._get_max_tokens(messages),
"n": 1,
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": 0.3,
"temperature": self.config.temperature,
"model": self.model,
"timeout": max(self.config.timeout, timeout),
}
@ -272,16 +269,6 @@ class OpenAILLM(BaseLLM):
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token

View file

@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
import copy
import os
import qianfan
from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import (
QIANFAN_ENDPOINT_TOKEN_COSTS,
QIANFAN_MODEL_TOKEN_COSTS,
)
@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
    """
    LLM provider for Baidu QianFan (ERNIE / wen xin yi yan and opensource models).

    Refs
         Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
         Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
         Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
                 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
    """

    def __init__(self, config: LLMConfig):
        self.config = config
        self.use_system_prompt = False  # only some ERNIE-x related models support system_prompt
        self.__init_qianfan()
        # provider-local cost manager carrying QianFan-specific token prices
        self.cost_manager = CostManager(token_costs=self.token_costs)

    def __init_qianfan(self):
        """Resolve auth scheme, validate model/endpoint config, and create the qianfan client."""
        if self.config.access_key and self.config.secret_key:
            # for system level auth, use access_key and secret_key, recommended by official
            # set environment variable due to official recommendation
            # NOTE(review): setdefault keeps pre-existing env values, so config is
            # silently ignored when the env vars are already set — confirm intended
            os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
            os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
        elif self.config.api_key and self.config.secret_key:
            # for application level auth, use api_key and secret_key
            # set environment variable due to official recommendation
            os.environ.setdefault("QIANFAN_AK", self.config.api_key)
            os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
        else:
            raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")

        support_system_pairs = [
            ("ERNIE-Bot-4", "completions_pro"),  # (model, corresponding-endpoint)
            ("ERNIE-Bot-8k", "ernie_bot_8k"),
            ("ERNIE-Bot", "completions"),
            ("ERNIE-Bot-turbo", "eb-instant"),
            ("ERNIE-Speed", "ernie_speed"),
            ("EB-turbo-AppBuilder", "ai_apaas"),
        ]
        if self.config.model in [pair[0] for pair in support_system_pairs]:
            # only some ERNIE models support
            self.use_system_prompt = True
        if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
            self.use_system_prompt = True

        # NOTE(review): asserts are stripped under `python -O`; consider raising
        # ValueError for this config validation instead
        assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
        assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"

        # merged price table: model-name keys plus endpoint-name keys
        self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
        self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)

        # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
        self.calc_usage = self.config.calc_usage and self.config.endpoint is None
        self.aclient: ChatCompletion = qianfan.ChatCompletion()

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the keyword arguments shared by all chat-completion calls."""
        kwargs = {
            "messages": messages,
            "stream": stream,
        }
        if self.config.temperature > 0:
            # different model has default temperature. only set when it's specified.
            kwargs["temperature"] = self.config.temperature
        # endpoint takes precedence over model when both could apply
        if self.config.endpoint:
            kwargs["endpoint"] = self.config.endpoint
        elif self.config.model:
            kwargs["model"] = self.config.model

        if self.use_system_prompt:
            # if the model support system prompt, extract and pass it
            if messages[0]["role"] == "system":
                kwargs["messages"] = messages[1:]
                kwargs["system"] = messages[0]["content"]  # set system prompt here
        return kwargs

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        # the cost key is whichever of model/endpoint is configured; only
        # calculate locally when that key has a known price entry
        model_or_endpoint = self.config.model or self.config.endpoint
        local_calc_usage = model_or_endpoint in self.token_costs
        super()._update_costs(usage, model_or_endpoint, local_calc_usage)

    def get_choice_text(self, resp: JsonBody) -> str:
        return resp.get("result", "")

    def completion(self, messages: list[dict]) -> JsonBody:
        """Synchronous, non-streaming completion."""
        resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def _achat_completion(self, messages: list[dict]) -> JsonBody:
        """Async, non-streaming completion."""
        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
        # NOTE(review): timeout is accepted for interface parity but not forwarded
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Async streaming completion; logs chunks as they arrive and returns the joined text."""
        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
        collected_content = []
        usage = {}
        async for chunk in resp:
            content = chunk.body.get("result", "")
            usage = chunk.body.get("usage", {})  # only the final chunk's usage is billed
            log_llm_stream(content)
            collected_content.append(content)
        log_llm_stream("\n")

        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
        """Async completion returning plain text; retries up to 3 times on ConnectionError."""
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)

View file

@ -53,16 +53,6 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()

View file

@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
i = action
self._init_action(i)
self.actions.append(i)
self.states.append(f"{len(self.actions)}. {action}")
self.states.append(f"{len(self.actions) - 1}. {action}")
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
"""Set strategy of the Role reacting to observed Message. Variation lies in how

View file

@ -2,14 +2,11 @@
# -*- coding: utf-8 -*-
import asyncio
import shutil
from pathlib import Path
import typer
from metagpt.config2 import config
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.context import Context
from metagpt.const import CONFIG_ROOT
from metagpt.utils.project_repo import ProjectRepo
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@ -30,6 +27,8 @@ def generate_repo(
recover_path=None,
) -> ProjectRepo:
"""Run the startup logic. Can be called from CLI or other Python scripts."""
from metagpt.config2 import config
from metagpt.context import Context
from metagpt.roles import (
Architect,
Engineer,
@ -122,7 +121,17 @@ def startup(
)
def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
llm:
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
base_url: "https://api.openai.com/v1" # or forward url / other llm url
api_key: "YOUR_API_KEY"
"""
def copy_config_to():
"""Initialize the configuration file for MetaGPT."""
target_path = CONFIG_ROOT / "config2.yaml"
@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
print(f"Existing configuration file backed up at {backup_path}")
# 复制文件
shutil.copy(str(config_path), target_path)
target_path.write_text(DEFAULT_CONFIG, encoding="utf-8")
print(f"Configuration file initialized at {target_path}")

View file

@ -29,6 +29,7 @@ class CostManager(BaseModel):
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -41,12 +42,13 @@ class CostManager(BaseModel):
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
if model not in TOKEN_COSTS:
if model not in self.token_costs:
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
return
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
prompt_tokens * self.token_costs[model]["prompt"]
+ completion_tokens * self.token_costs[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(

View file

@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
arr = output.split("\n")
new_arr = []
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
# problem, `"""` or `'''` without `,`
new_line = f",{line}"
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
# backslash problem like \" in the output
char = rline[col_no - 1]
nearest_char_idx = rline[col_no:].find(char)
new_line = (
rline[: col_no - 1]
+ "\\"
+ rline[col_no - 1 : col_no + nearest_char_idx]
+ "\\"
+ rline[col_no + nearest_char_idx :]
)
elif '",' not in line and "," not in line and '"' not in line:
new_line = f'{line}",'
elif not line.endswith(","):

View file

@ -38,6 +38,88 @@ TOKEN_COSTS = {
}
"""
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
"""
# keyed by official model name; values are per-1K-token prices
# NOTE(review): presumably CNY per 1K tokens per the linked price page — confirm
QIANFAN_MODEL_TOKEN_COSTS = {
    "ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
    "ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
    "ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
    "ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
    "EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
    "ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
    "BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
    "Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
    "Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
    "Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
    "ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
    "AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
    "Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
    "SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
    "CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
    "XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
    "Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
    "Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
    "Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
    "ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
    "Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}

# same price table keyed by endpoint name (for configs that set `endpoint`);
# each entry aliases the corresponding model's cost dict above
QIANFAN_ENDPOINT_TOKEN_COSTS = {
    "completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
    "ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
    "completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
    "eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
    "ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
    "ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
    "bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
    "llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
    "llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
    "llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
    "chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
    "aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
    "mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
    "sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
    "codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
    "xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
    "qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
    "qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
    "qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
    "chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
    "yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
"""
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Different model has different detail page. Attention, some model are free for a limited time.
"""
# keyed by DashScope model name; 0.0 entries are free (some only temporarily)
# NOTE(review): presumably CNY per 1K tokens per the linked billing page — confirm
DASHSCOPE_TOKEN_COSTS = {
    "qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
    "qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
    "qwen-max": {"prompt": 0.0, "completion": 0.0},
    "qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
    "qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
    "llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
    "llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
    "qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
    "qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
    "qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
    "qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
    "baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
    "baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
    "baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
    "chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
    "chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
    "ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0},  # no price page, judge it as free
    "dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
    "belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
    "moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
    "chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
    "billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,