feat: merge main

This commit is contained in:
莘权 马 2024-08-06 16:22:42 +08:00
commit f19fcfa2df
407 changed files with 20083 additions and 1174 deletions

View file

@ -10,8 +10,8 @@ from metagpt.utils.read_document import read_docx
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
TOKEN_COSTS,
count_message_tokens,
count_string_tokens,
count_input_tokens,
count_output_tokens,
)
@ -19,6 +19,6 @@ __all__ = [
"read_docx",
"Singleton",
"TOKEN_COSTS",
"count_message_tokens",
"count_string_tokens",
"count_input_tokens",
"count_output_tokens",
]

View file

@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
new_loop.call_soon_threadsafe(new_loop.stop)
t.join()
new_loop.close()
class NestAsyncio:
"""Make asyncio event loop reentrant."""
is_applied = False
@classmethod
def apply_once(cls):
"""Ensures `nest_asyncio.apply()` is called only once."""
if not cls.is_applied:
import nest_asyncio
nest_asyncio.apply()
cls.is_applied = True

View file

@ -722,7 +722,10 @@ def list_files(root: str | Path) -> List[Path]:
def parse_json_code_block(markdown_text: str) -> List[str]:
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
json_blocks = (
re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text]
)
return [v.strip() for v in json_blocks]
@ -838,3 +841,21 @@ def get_markdown_codeblock_type(filename: str) -> str:
"application/sql": "sql",
}
return mappings.get(mime_type, "text")
def download_model(file_url: str, target_folder: Path) -> Path:
file_name = file_url.split("/")[-1]
file_path = target_folder.joinpath(f"{file_name}")
if not file_path.exists():
file_path.mkdir(parents=True, exist_ok=True)
try:
response = requests.get(file_url, stream=True)
response.raise_for_status() # 检查请求是否成功
# 保存文件
with open(file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f"权重文件已下载并保存至 {file_path}")
except requests.exceptions.HTTPError as err:
logger.info(f"权重文件下载过程中发生错误: {err}")
return file_path

View file

@ -91,7 +91,7 @@ class DependencyFile:
try:
key = Path(filename).relative_to(root).as_posix()
except ValueError:
key = filename
key = Path(filename).as_posix()
return set(self._dependencies.get(str(key), {}))
def delete_file(self):

View file

@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
def __init__(self, name: str | Path, **kwargs):
super().__init__(name=str(name), **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
@ -112,8 +112,28 @@ class DiGraphRepository(GraphRepository):
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self.load_json(data)
def load_json(self, val: str):
"""
Loads a JSON-encoded string representing a graph structure and updates
the internal repository (_repo) with the parsed graph.
Args:
val (str): A JSON-encoded string representing a graph structure.
Returns:
self: Returns the instance of the class with the updated _repo attribute.
Raises:
TypeError: If val is not a valid JSON string or cannot be parsed into
a valid graph structure.
"""
if not val:
return self
m = json.loads(val)
self._repo = networkx.node_link_graph(m)
return self
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
@ -126,9 +146,7 @@ class DiGraphRepository(GraphRepository):
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
graph = DiGraphRepository(name=pathname.stem, root=pathname.parent)
if pathname.exists():
await graph.load(pathname=pathname)
return graph

View file

@ -78,7 +78,7 @@ class GitRepository:
self._repository = Repo.init(path=Path(local_path))
gitignore_filename = Path(local_path) / ".gitignore"
ignores = ["__pycache__", "*.pyc"]
ignores = ["__pycache__", "*.pyc", ".vs"]
with open(str(gitignore_filename), mode="w") as writer:
writer.write("\n".join(ignores))
self._repository.index.add([".gitignore"])

View file

@ -81,6 +81,8 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
from metagpt.utils.mmdc_ink import mermaid_to_file
return await mermaid_to_file(mermaid_code, output_file_without_suffix)
elif engine == "none":
return 0
else:
logger.warning(f"Unsupported mermaid engine: {engine}")
return 0

View file

@ -53,30 +53,30 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await page.wait_for_load_state("networkidle")
await page.wait_for_selector("div#container", state="attached")
# mermaid_config = {}
mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
#
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
await page.evaluate(
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}""",
[mermaid_code, mermaid_config, my_css, background_color],
)
if "svg" in suffixes:
svg_xml = await page.evaluate(

View file

@ -55,29 +55,29 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await page.goto(mermaid_html_url)
await page.querySelector("div#container")
# mermaid_config = {}
mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
await page.evaluate(
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}""",
[mermaid_code, mermaid_config, my_css, background_color],
)
if "svg" in suffixes:
svg_xml = await page.evaluate(

View file

@ -0,0 +1,239 @@
import mimetypes
import os
from pathlib import Path
from typing import Union
import httpx
from metagpt.rag.schema import OmniParsedResult
from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
This client interacts with the OmniParse server to parse different types of media, documents.
OmniParse API Documentation: https://docs.cognitivelab.in/api
Attributes:
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: Base URL for the API.
max_timeout: Maximum request timeout in seconds.
"""
self.api_key = api_key
self.base_url = base_url
self.max_timeout = max_timeout
self.parse_media_endpoint = "/parse_media"
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
async def _request_parse(
self,
endpoint: str,
method: str = "POST",
files: dict = None,
params: dict = None,
data: dict = None,
json: dict = None,
headers: dict = None,
**kwargs,
) -> dict:
"""
Request OmniParse API to parse a document.
Args:
endpoint (str): API endpoint.
method (str, optional): HTTP method to use. Default is "POST".
files (dict, optional): Files to include in the request.
params (dict, optional): Query string parameters.
data (dict, optional): Form data to include in the request body.
json (dict, optional): JSON data to include in the request body.
headers (dict, optional): HTTP headers to include in the request.
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
Returns:
dict: JSON response data.
"""
url = f"{self.base_url}{endpoint}"
method = method.upper()
headers = headers or {}
_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
headers.update(**_headers)
async with httpx.AsyncClient() as client:
response = await client.request(
url=url,
method=method,
files=files,
params=params,
json=json,
data=data,
headers=headers,
timeout=self.max_timeout,
**kwargs,
)
response.raise_for_status()
return response.json()
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
file_input: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(file_input, {".pdf"})
# parse_pdf supports parsing by accepting only the byte data of the file.
file_info = await self.get_file_info(file_input, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
Verify the file extension.
Args:
file_input: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `file_input` is byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
"""
verify_file_path = None
if isinstance(file_input, (str, Path)):
verify_file_path = str(file_input)
elif isinstance(file_input, bytes) and bytes_filename:
verify_file_path = bytes_filename
if not verify_file_path:
# Do not verify if only byte data is provided
return
file_ext = os.path.splitext(verify_file_path)[1].lower()
if file_ext not in allowed_file_extensions:
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
@staticmethod
async def get_file_info(
file_input: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
"""
Get file information.
Args:
file_input: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
Notes:
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
the MIME type of the file must be specified when uploading.
Returns: [bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(file_input, (str, Path)):
filename = os.path.basename(str(file_input))
file_bytes = await aread_bin(file_input)
if only_bytes:
return file_bytes
mime_type = mimetypes.guess_type(file_input)[0]
return filename, file_bytes, mime_type
elif isinstance(file_input, bytes):
if only_bytes:
return file_input
if not bytes_filename:
raise ValueError("bytes_filename must be set when passing bytes")
mime_type = mimetypes.guess_type(bytes_filename)[0]
return bytes_filename, file_input, mime_type
else:
raise ValueError("file_input must be a string (file path) or bytes.")

View file

@ -3,7 +3,7 @@ from typing import Tuple
def remove_spaces(text):
return re.sub(r"\s+", " ", text).strip()
return re.sub(r"\s+", " ", text).strip() if text else ""
class DocstringParser:

View file

@ -10,7 +10,7 @@ from __future__ import annotations
import traceback
from datetime import timedelta
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
import redis.asyncio as aioredis
from metagpt.configs.redis_config import RedisConfig
from metagpt.logs import logger

View file

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# @Time : 2024/3/27 10:00
# @Author : leiwu30
# @File : stream_pipe.py
# @Version : None
# @Description : None
import json
import time
from multiprocessing import Pipe
class StreamPipe:
def __init__(self, name=None):
self.name = name
self.parent_conn, self.child_conn = Pipe()
self.finish: bool = False
format_data = {
"id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
"object": "chat.completion.chunk",
"created": 1711361191,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_3bc1b5746c",
"choices": [
{"index": 0, "delta": {"role": "assistant", "content": "content"}, "logprobs": None, "finish_reason": None}
],
}
def set_message(self, msg):
self.parent_conn.send(msg)
def get_message(self, timeout: int = 3):
if self.child_conn.poll(timeout):
return self.child_conn.recv()
else:
return None
def msg2stream(self, msg):
self.format_data["created"] = int(time.time())
self.format_data["choices"][0]["delta"]["content"] = msg
return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8")

View file

@ -1,6 +1,6 @@
from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
from metagpt.utils.token_counter import TOKEN_MAX, count_output_tokens
def reduce_message_length(
@ -23,9 +23,9 @@ def reduce_message_length(
Raises:
RuntimeError: If it fails to reduce the concatenated message length.
"""
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
max_token = TOKEN_MAX.get(model_name, 2048) - count_output_tokens(system_text, model_name) - reserved
for msg in msgs:
if count_string_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX:
if count_output_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX:
return msg
raise RuntimeError("fail to reduce message length")
@ -54,13 +54,13 @@ def generate_prompt_chunk(
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
reserved = reserved + count_output_tokens(prompt_template + system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
token = count_string_tokens(paragraph, model_name)
token = count_output_tokens(paragraph, model_name)
if current_token + token <= max_token:
current_lines.append(paragraph)
current_token += token

View file

@ -11,6 +11,11 @@ ref4: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/o
ref5: https://ai.google.dev/models/gemini
"""
import tiktoken
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
from metagpt.logs import logger
from metagpt.utils.ahttp_client import apost
TOKEN_COSTS = {
"gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
@ -28,12 +33,15 @@ TOKEN_COSTS = {
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
"gpt-4-turbo-2024-04-09": {"prompt": 0.01, "completion": 0.03},
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4o": {"prompt": 0.005, "completion": 0.015},
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
"gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
@ -50,9 +58,28 @@ TOKEN_COSTS = {
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
"claude-3-5-sonnet-20240620": {"prompt": 0.003, "completion": 0.015},
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
"claude-3-haiku-20240307": {"prompt": 0.00025, "completion": 0.00125},
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
"yi-large": {"prompt": 0.0028, "completion": 0.0028},
"microsoft/wizardlm-2-8x22b": {"prompt": 0.00108, "completion": 0.00108}, # for openrouter, start
"meta-llama/llama-3-70b-instruct": {"prompt": 0.008, "completion": 0.008},
"llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
"openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
# For ark model https://www.volcengine.com/docs/82379/1099320
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
}
@ -113,14 +140,30 @@ DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/t
Different model has different detail page. Attention, some model are free for a limited time.
"""
DASHSCOPE_TOKEN_COSTS = {
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
"qwen-max": {"prompt": 0.0, "completion": 0.0},
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
"qwen2-72b-instruct": {"prompt": 0.000714, "completion": 0.001428},
"qwen2-57b-a14b-instruct": {"prompt": 0.0005, "completion": 0.001},
"qwen2-7b-instruct": {"prompt": 0.000143, "completion": 0.000286},
"qwen2-1.5b-instruct": {"prompt": 0, "completion": 0},
"qwen2-0.5b-instruct": {"prompt": 0, "completion": 0},
"qwen1.5-110b-chat": {"prompt": 0.001, "completion": 0.002},
"qwen1.5-72b-chat": {"prompt": 0.000714, "completion": 0.001428},
"qwen1.5-32b-chat": {"prompt": 0.0005, "completion": 0.001},
"qwen1.5-14b-chat": {"prompt": 0.000286, "completion": 0.000571},
"qwen1.5-7b-chat": {"prompt": 0.000143, "completion": 0.000286},
"qwen1.5-1.8b-chat": {"prompt": 0, "completion": 0},
"qwen1.5-0.5b-chat": {"prompt": 0, "completion": 0},
"qwen-turbo": {"prompt": 0.00028, "completion": 0.00083},
"qwen-long": {"prompt": 0.00007, "completion": 0.00028},
"qwen-plus": {"prompt": 0.00055, "completion": 0.00166},
"qwen-max": {"prompt": 0.0055, "completion": 0.0166},
"qwen-max-0428": {"prompt": 0.0055, "completion": 0.0166},
"qwen-max-0403": {"prompt": 0.0055, "completion": 0.0166},
"qwen-max-0107": {"prompt": 0.0055, "completion": 0.0166},
"qwen-max-1201": {"prompt": 0.0166, "completion": 0.0166},
"qwen-max-longcontext": {"prompt": 0.0055, "completion": 0.0166},
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
"qwen-72b-chat": {"prompt": 0.0028, "completion": 0.0028},
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
@ -147,9 +190,13 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
TOKEN_MAX = {
"gpt-4o-2024-05-13": 128000,
"gpt-4o": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-turbo": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"gpt-4": 8192,
@ -157,7 +204,6 @@ TOKEN_MAX = {
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-4o-mini": 128000,
"gpt-4o": 128000,
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
@ -182,17 +228,108 @@ TOKEN_MAX = {
"claude-2.1": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-opus-20240229": 200000,
"claude-3-5-sonnet-20240620": 200000,
"claude-3-haiku-20240307": 200000,
"yi-34b-chat-0205": 4000,
"yi-34b-chat-200k": 200000,
"yi-large": 16385,
"microsoft/wizardlm-2-8x22b": 65536,
"meta-llama/llama-3-70b-instruct": 8192,
"llama3-70b-8192": 8192,
"openai/gpt-3.5-turbo-0125": 16385,
"openai/gpt-4-turbo-preview": 128000,
"deepseek-chat": 32768,
"deepseek-coder": 16385,
"doubao-lite-4k-240515": 4000,
"doubao-lite-32k-240515": 32000,
"doubao-lite-128k-240515": 128000,
"doubao-pro-4k-240515": 4000,
"doubao-pro-32k-240515": 32000,
"doubao-pro-128k-240515": 128000,
# Qwen https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-72b-api-detailes?spm=a2c4g.11186623.0.i20
"qwen2-57b-a14b-instruct": 32768,
"qwen2-72b-instruct": 131072,
"qwen2-7b-instruct": 32768,
"qwen2-1.5b-instruct": 32768,
"qwen2-0.5b-instruct": 32768,
"qwen1.5-110b-chat": 32000,
"qwen1.5-72b-chat": 32000,
"qwen1.5-32b-chat": 32000,
"qwen1.5-14b-chat": 8000,
"qwen1.5-7b-chat": 32000,
"qwen1.5-1.8b-chat": 32000,
"qwen1.5-0.5b-chat": 32000,
"codeqwen1.5-7b-chat": 64000,
"qwen-72b-chat": 32000,
"qwen-14b-chat": 8000,
"qwen-7b-chat": 32000,
"qwen-1.8b-longcontext-chat": 32000,
"qwen-1.8b-chat": 8000,
}
# For Amazon Bedrock US region
# See https://aws.amazon.com/cn/bedrock/pricing/
BEDROCK_TOKEN_COSTS = {
"amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008},
"amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008},
"amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008},
"amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004},
"amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004},
"anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024},
"anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024},
"anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
"anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075},
"cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015},
"cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015},
"cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003},
"cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003},
"meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001},
"meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001},
"meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006},
"meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035},
"mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002},
"mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007},
"mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024},
"ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125},
"ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188},
"ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125},
"ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125},
"ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188},
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
}
# https://xinghuo.xfyun.cn/sparkapi?scr=price
SPARK_TOKENS = {
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
"generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
"generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
"generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
}
def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
logger.info(f"Warning: model {model} not found in tiktoken. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
@ -206,12 +343,14 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-turbo",
"gpt-4-vision-preview",
"gpt-4-1106-vision-preview",
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
}:
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
@ -220,11 +359,11 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" == model:
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
return count_message_tokens(messages, model="gpt-3.5-turbo-0125")
logger.info("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
return count_input_tokens(messages, model="gpt-3.5-turbo-0125")
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
logger.info("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_input_tokens(messages, model="gpt-4-0613")
elif "open-llm-model" == model:
"""
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
@ -255,21 +394,21 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
return num_tokens
def count_string_tokens(string: str, model_name: str) -> int:
def count_output_tokens(string: str, model: str) -> int:
"""
Returns the number of tokens in a text string.
Args:
string (str): The text string.
model_name (str): The name of the encoding to use. (e.g., "gpt-3.5-turbo")
model (str): The name of the encoding to use. (e.g., "gpt-3.5-turbo")
Returns:
int: The number of tokens in the text string.
"""
try:
encoding = tiktoken.encoding_for_model(model_name)
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
logger.info(f"Warning: model {model} not found in tiktoken. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(string))
@ -286,4 +425,16 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) ->
"""
if model not in TOKEN_MAX:
return default
return TOKEN_MAX[model] - count_message_tokens(messages) - 1
return TOKEN_MAX[model] - count_input_tokens(messages) - 1
async def get_openrouter_tokens(chunk: ChatCompletionChunk) -> CompletionUsage:
"""refs to https://openrouter.ai/docs#querying-cost-and-stats"""
url = f"https://openrouter.ai/api/v1/generation?id={chunk.id}"
resp = await apost(url=url, as_json=True)
tokens_prompt = resp.get("tokens_prompt", 0)
completion_tokens = resp.get("tokens_completion", 0)
usage = CompletionUsage(
prompt_tokens=tokens_prompt, completion_tokens=completion_tokens, total_tokens=tokens_prompt + completion_tokens
)
return usage