mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-27 09:46:24 +02:00
feat: merge main
This commit is contained in:
commit
f19fcfa2df
407 changed files with 20083 additions and 1174 deletions
|
|
@ -10,8 +10,8 @@ from metagpt.utils.read_document import read_docx
|
|||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
count_input_tokens,
|
||||
count_output_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -19,6 +19,6 @@ __all__ = [
|
|||
"read_docx",
|
||||
"Singleton",
|
||||
"TOKEN_COSTS",
|
||||
"count_message_tokens",
|
||||
"count_string_tokens",
|
||||
"count_input_tokens",
|
||||
"count_output_tokens",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
|
|||
new_loop.call_soon_threadsafe(new_loop.stop)
|
||||
t.join()
|
||||
new_loop.close()
|
||||
|
||||
|
||||
class NestAsyncio:
|
||||
"""Make asyncio event loop reentrant."""
|
||||
|
||||
is_applied = False
|
||||
|
||||
@classmethod
|
||||
def apply_once(cls):
|
||||
"""Ensures `nest_asyncio.apply()` is called only once."""
|
||||
if not cls.is_applied:
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
cls.is_applied = True
|
||||
|
|
|
|||
|
|
@ -722,7 +722,10 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
|
||||
|
||||
def parse_json_code_block(markdown_text: str) -> List[str]:
|
||||
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
|
||||
json_blocks = (
|
||||
re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text]
|
||||
)
|
||||
|
||||
return [v.strip() for v in json_blocks]
|
||||
|
||||
|
||||
|
|
@ -838,3 +841,21 @@ def get_markdown_codeblock_type(filename: str) -> str:
|
|||
"application/sql": "sql",
|
||||
}
|
||||
return mappings.get(mime_type, "text")
|
||||
|
||||
|
||||
def download_model(file_url: str, target_folder: Path) -> Path:
|
||||
file_name = file_url.split("/")[-1]
|
||||
file_path = target_folder.joinpath(f"{file_name}")
|
||||
if not file_path.exists():
|
||||
file_path.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
response = requests.get(file_url, stream=True)
|
||||
response.raise_for_status() # 检查请求是否成功
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logger.info(f"权重文件已下载并保存至 {file_path}")
|
||||
except requests.exceptions.HTTPError as err:
|
||||
logger.info(f"权重文件下载过程中发生错误: {err}")
|
||||
return file_path
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class DependencyFile:
|
|||
try:
|
||||
key = Path(filename).relative_to(root).as_posix()
|
||||
except ValueError:
|
||||
key = filename
|
||||
key = Path(filename).as_posix()
|
||||
return set(self._dependencies.get(str(key), {}))
|
||||
|
||||
def delete_file(self):
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
|
|||
class DiGraphRepository(GraphRepository):
|
||||
"""Graph repository based on DiGraph."""
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
super().__init__(name=name, **kwargs)
|
||||
def __init__(self, name: str | Path, **kwargs):
|
||||
super().__init__(name=str(name), **kwargs)
|
||||
self._repo = networkx.DiGraph()
|
||||
|
||||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
|
|
@ -112,8 +112,28 @@ class DiGraphRepository(GraphRepository):
|
|||
async def load(self, pathname: str | Path):
|
||||
"""Load a directed graph repository from a JSON file."""
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
m = json.loads(data)
|
||||
self.load_json(data)
|
||||
|
||||
def load_json(self, val: str):
|
||||
"""
|
||||
Loads a JSON-encoded string representing a graph structure and updates
|
||||
the internal repository (_repo) with the parsed graph.
|
||||
|
||||
Args:
|
||||
val (str): A JSON-encoded string representing a graph structure.
|
||||
|
||||
Returns:
|
||||
self: Returns the instance of the class with the updated _repo attribute.
|
||||
|
||||
Raises:
|
||||
TypeError: If val is not a valid JSON string or cannot be parsed into
|
||||
a valid graph structure.
|
||||
"""
|
||||
if not val:
|
||||
return self
|
||||
m = json.loads(val)
|
||||
self._repo = networkx.node_link_graph(m)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
async def load_from(pathname: str | Path) -> GraphRepository:
|
||||
|
|
@ -126,9 +146,7 @@ class DiGraphRepository(GraphRepository):
|
|||
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
|
||||
"""
|
||||
pathname = Path(pathname)
|
||||
name = pathname.with_suffix("").name
|
||||
root = pathname.parent
|
||||
graph = DiGraphRepository(name=name, root=root)
|
||||
graph = DiGraphRepository(name=pathname.stem, root=pathname.parent)
|
||||
if pathname.exists():
|
||||
await graph.load(pathname=pathname)
|
||||
return graph
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class GitRepository:
|
|||
self._repository = Repo.init(path=Path(local_path))
|
||||
|
||||
gitignore_filename = Path(local_path) / ".gitignore"
|
||||
ignores = ["__pycache__", "*.pyc"]
|
||||
ignores = ["__pycache__", "*.pyc", ".vs"]
|
||||
with open(str(gitignore_filename), mode="w") as writer:
|
||||
writer.write("\n".join(ignores))
|
||||
self._repository.index.add([".gitignore"])
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
from metagpt.utils.mmdc_ink import mermaid_to_file
|
||||
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix)
|
||||
elif engine == "none":
|
||||
return 0
|
||||
else:
|
||||
logger.warning(f"Unsupported mermaid engine: {engine}")
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -53,30 +53,30 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
await page.wait_for_load_state("networkidle")
|
||||
|
||||
await page.wait_for_selector("div#container", state="attached")
|
||||
# mermaid_config = {}
|
||||
mermaid_config = {}
|
||||
background_color = "#ffffff"
|
||||
# my_css = ""
|
||||
my_css = ""
|
||||
await page.evaluate(f'document.body.style.background = "{background_color}";')
|
||||
|
||||
# metadata = await page.evaluate(
|
||||
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
# const { mermaid, zenuml } = globalThis;
|
||||
# await mermaid.registerExternalDiagrams([zenuml]);
|
||||
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
# document.getElementById('container').innerHTML = svg;
|
||||
# const svgElement = document.querySelector('svg');
|
||||
# svgElement.style.backgroundColor = backgroundColor;
|
||||
#
|
||||
# if (myCSS) {
|
||||
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
# style.appendChild(document.createTextNode(myCSS));
|
||||
# svgElement.appendChild(style);
|
||||
# }
|
||||
#
|
||||
# }""",
|
||||
# [mermaid_code, mermaid_config, my_css, background_color],
|
||||
# )
|
||||
await page.evaluate(
|
||||
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
const { mermaid, zenuml } = globalThis;
|
||||
await mermaid.registerExternalDiagrams([zenuml]);
|
||||
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
document.getElementById('container').innerHTML = svg;
|
||||
const svgElement = document.querySelector('svg');
|
||||
svgElement.style.backgroundColor = backgroundColor;
|
||||
|
||||
if (myCSS) {
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.appendChild(document.createTextNode(myCSS));
|
||||
svgElement.appendChild(style);
|
||||
}
|
||||
|
||||
}""",
|
||||
[mermaid_code, mermaid_config, my_css, background_color],
|
||||
)
|
||||
|
||||
if "svg" in suffixes:
|
||||
svg_xml = await page.evaluate(
|
||||
|
|
|
|||
|
|
@ -55,29 +55,29 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
await page.goto(mermaid_html_url)
|
||||
|
||||
await page.querySelector("div#container")
|
||||
# mermaid_config = {}
|
||||
mermaid_config = {}
|
||||
background_color = "#ffffff"
|
||||
# my_css = ""
|
||||
my_css = ""
|
||||
await page.evaluate(f'document.body.style.background = "{background_color}";')
|
||||
|
||||
# metadata = await page.evaluate(
|
||||
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
# const { mermaid, zenuml } = globalThis;
|
||||
# await mermaid.registerExternalDiagrams([zenuml]);
|
||||
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
# document.getElementById('container').innerHTML = svg;
|
||||
# const svgElement = document.querySelector('svg');
|
||||
# svgElement.style.backgroundColor = backgroundColor;
|
||||
#
|
||||
# if (myCSS) {
|
||||
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
# style.appendChild(document.createTextNode(myCSS));
|
||||
# svgElement.appendChild(style);
|
||||
# }
|
||||
# }""",
|
||||
# [mermaid_code, mermaid_config, my_css, background_color],
|
||||
# )
|
||||
await page.evaluate(
|
||||
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
const { mermaid, zenuml } = globalThis;
|
||||
await mermaid.registerExternalDiagrams([zenuml]);
|
||||
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
document.getElementById('container').innerHTML = svg;
|
||||
const svgElement = document.querySelector('svg');
|
||||
svgElement.style.backgroundColor = backgroundColor;
|
||||
|
||||
if (myCSS) {
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.appendChild(document.createTextNode(myCSS));
|
||||
svgElement.appendChild(style);
|
||||
}
|
||||
}""",
|
||||
[mermaid_code, mermaid_config, my_css, background_color],
|
||||
)
|
||||
|
||||
if "svg" in suffixes:
|
||||
svg_xml = await page.evaluate(
|
||||
|
|
|
|||
239
metagpt/utils/omniparse_client.py
Normal file
239
metagpt/utils/omniparse_client.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
|
||||
from metagpt.rag.schema import OmniParsedResult
|
||||
from metagpt.utils.common import aread_bin
|
||||
|
||||
|
||||
class OmniParseClient:
|
||||
"""
|
||||
OmniParse Server Client
|
||||
This client interacts with the OmniParse server to parse different types of media, documents.
|
||||
|
||||
OmniParse API Documentation: https://docs.cognitivelab.in/api
|
||||
|
||||
Attributes:
|
||||
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
|
||||
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
|
||||
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
|
||||
"""
|
||||
|
||||
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
|
||||
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
|
||||
"""
|
||||
Args:
|
||||
api_key: Default None, can be used for authentication later.
|
||||
base_url: Base URL for the API.
|
||||
max_timeout: Maximum request timeout in seconds.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.max_timeout = max_timeout
|
||||
|
||||
self.parse_media_endpoint = "/parse_media"
|
||||
self.parse_website_endpoint = "/parse_website"
|
||||
self.parse_document_endpoint = "/parse_document"
|
||||
|
||||
async def _request_parse(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
files: dict = None,
|
||||
params: dict = None,
|
||||
data: dict = None,
|
||||
json: dict = None,
|
||||
headers: dict = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Request OmniParse API to parse a document.
|
||||
|
||||
Args:
|
||||
endpoint (str): API endpoint.
|
||||
method (str, optional): HTTP method to use. Default is "POST".
|
||||
files (dict, optional): Files to include in the request.
|
||||
params (dict, optional): Query string parameters.
|
||||
data (dict, optional): Form data to include in the request body.
|
||||
json (dict, optional): JSON data to include in the request body.
|
||||
headers (dict, optional): HTTP headers to include in the request.
|
||||
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
method = method.upper()
|
||||
headers = headers or {}
|
||||
_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
|
||||
headers.update(**_headers)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(
|
||||
url=url,
|
||||
method=method,
|
||||
files=files,
|
||||
params=params,
|
||||
json=json,
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=self.max_timeout,
|
||||
**kwargs,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
"""
|
||||
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
OmniParsedResult: The result of the document parsing.
|
||||
"""
|
||||
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
"""
|
||||
Parse pdf document.
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
OmniParsedResult: The result of the pdf parsing.
|
||||
"""
|
||||
self.verify_file_ext(file_input, {".pdf"})
|
||||
# parse_pdf supports parsing by accepting only the byte data of the file.
|
||||
file_info = await self.get_file_info(file_input, only_bytes=True)
|
||||
endpoint = f"{self.parse_document_endpoint}/pdf"
|
||||
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse audio-type data (supports ".mp3", ".wav", ".aac").
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
@staticmethod
|
||||
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
|
||||
"""
|
||||
Verify the file extension.
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
allowed_file_extensions: Set of allowed file extensions.
|
||||
bytes_filename: Filename to use for verification when `file_input` is byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
"""
|
||||
verify_file_path = None
|
||||
if isinstance(file_input, (str, Path)):
|
||||
verify_file_path = str(file_input)
|
||||
elif isinstance(file_input, bytes) and bytes_filename:
|
||||
verify_file_path = bytes_filename
|
||||
|
||||
if not verify_file_path:
|
||||
# Do not verify if only byte data is provided
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(verify_file_path)[1].lower()
|
||||
if file_ext not in allowed_file_extensions:
|
||||
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
|
||||
|
||||
@staticmethod
|
||||
async def get_file_info(
|
||||
file_input: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes: bool = False,
|
||||
) -> Union[bytes, tuple]:
|
||||
"""
|
||||
Get file information.
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
|
||||
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
|
||||
|
||||
Raises:
|
||||
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
|
||||
|
||||
Notes:
|
||||
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
|
||||
the MIME type of the file must be specified when uploading.
|
||||
|
||||
Returns: [bytes, tuple]
|
||||
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
|
||||
"""
|
||||
if isinstance(file_input, (str, Path)):
|
||||
filename = os.path.basename(str(file_input))
|
||||
file_bytes = await aread_bin(file_input)
|
||||
|
||||
if only_bytes:
|
||||
return file_bytes
|
||||
|
||||
mime_type = mimetypes.guess_type(file_input)[0]
|
||||
return filename, file_bytes, mime_type
|
||||
elif isinstance(file_input, bytes):
|
||||
if only_bytes:
|
||||
return file_input
|
||||
if not bytes_filename:
|
||||
raise ValueError("bytes_filename must be set when passing bytes")
|
||||
|
||||
mime_type = mimetypes.guess_type(bytes_filename)[0]
|
||||
return bytes_filename, file_input, mime_type
|
||||
else:
|
||||
raise ValueError("file_input must be a string (file path) or bytes.")
|
||||
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
|||
|
||||
|
||||
def remove_spaces(text):
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
return re.sub(r"\s+", " ", text).strip() if text else ""
|
||||
|
||||
|
||||
class DocstringParser:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from __future__ import annotations
|
|||
import traceback
|
||||
from datetime import timedelta
|
||||
|
||||
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
from metagpt.logs import logger
|
||||
|
|
|
|||
42
metagpt/utils/stream_pipe.py
Normal file
42
metagpt/utils/stream_pipe.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2024/3/27 10:00
|
||||
# @Author : leiwu30
|
||||
# @File : stream_pipe.py
|
||||
# @Version : None
|
||||
# @Description : None
|
||||
|
||||
import json
|
||||
import time
|
||||
from multiprocessing import Pipe
|
||||
|
||||
|
||||
class StreamPipe:
|
||||
def __init__(self, name=None):
|
||||
self.name = name
|
||||
self.parent_conn, self.child_conn = Pipe()
|
||||
self.finish: bool = False
|
||||
|
||||
format_data = {
|
||||
"id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1711361191,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"system_fingerprint": "fp_3bc1b5746c",
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"role": "assistant", "content": "content"}, "logprobs": None, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
|
||||
def set_message(self, msg):
|
||||
self.parent_conn.send(msg)
|
||||
|
||||
def get_message(self, timeout: int = 3):
|
||||
if self.child_conn.poll(timeout):
|
||||
return self.child_conn.recv()
|
||||
else:
|
||||
return None
|
||||
|
||||
def msg2stream(self, msg):
|
||||
self.format_data["created"] = int(time.time())
|
||||
self.format_data["choices"][0]["delta"]["content"] = msg
|
||||
return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8")
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Generator, Sequence
|
||||
|
||||
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
|
||||
from metagpt.utils.token_counter import TOKEN_MAX, count_output_tokens
|
||||
|
||||
|
||||
def reduce_message_length(
|
||||
|
|
@ -23,9 +23,9 @@ def reduce_message_length(
|
|||
Raises:
|
||||
RuntimeError: If it fails to reduce the concatenated message length.
|
||||
"""
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - count_output_tokens(system_text, model_name) - reserved
|
||||
for msg in msgs:
|
||||
if count_string_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX:
|
||||
if count_output_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX:
|
||||
return msg
|
||||
|
||||
raise RuntimeError("fail to reduce message length")
|
||||
|
|
@ -54,13 +54,13 @@ def generate_prompt_chunk(
|
|||
current_token = 0
|
||||
current_lines = []
|
||||
|
||||
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
|
||||
reserved = reserved + count_output_tokens(prompt_template + system_text, model_name)
|
||||
# 100 is a magic number to ensure the maximum context length is not exceeded
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
|
||||
|
||||
while paragraphs:
|
||||
paragraph = paragraphs.pop(0)
|
||||
token = count_string_tokens(paragraph, model_name)
|
||||
token = count_output_tokens(paragraph, model_name)
|
||||
if current_token + token <= max_token:
|
||||
current_lines.append(paragraph)
|
||||
current_token += token
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ ref4: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/o
|
|||
ref5: https://ai.google.dev/models/gemini
|
||||
"""
|
||||
import tiktoken
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.ahttp_client import apost
|
||||
|
||||
TOKEN_COSTS = {
|
||||
"gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
|
||||
|
|
@ -28,12 +33,15 @@ TOKEN_COSTS = {
|
|||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo-2024-04-09": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
|
||||
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4o": {"prompt": 0.005, "completion": 0.015},
|
||||
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
|
||||
"gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
|
|
@ -50,9 +58,28 @@ TOKEN_COSTS = {
|
|||
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-5-sonnet-20240620": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
"claude-3-haiku-20240307": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"yi-large": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"microsoft/wizardlm-2-8x22b": {"prompt": 0.00108, "completion": 0.00108}, # for openrouter, start
|
||||
"meta-llama/llama-3-70b-instruct": {"prompt": 0.008, "completion": 0.008},
|
||||
"llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
|
||||
"openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
|
||||
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
|
||||
# For ark model https://www.volcengine.com/docs/82379/1099320
|
||||
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
|
||||
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
|
||||
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
|
||||
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
|
||||
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
|
||||
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
|
||||
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -113,14 +140,30 @@ DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/t
|
|||
Different model has different detail page. Attention, some model are free for a limited time.
|
||||
"""
|
||||
DASHSCOPE_TOKEN_COSTS = {
|
||||
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"qwen-max": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen2-72b-instruct": {"prompt": 0.000714, "completion": 0.001428},
|
||||
"qwen2-57b-a14b-instruct": {"prompt": 0.0005, "completion": 0.001},
|
||||
"qwen2-7b-instruct": {"prompt": 0.000143, "completion": 0.000286},
|
||||
"qwen2-1.5b-instruct": {"prompt": 0, "completion": 0},
|
||||
"qwen2-0.5b-instruct": {"prompt": 0, "completion": 0},
|
||||
"qwen1.5-110b-chat": {"prompt": 0.001, "completion": 0.002},
|
||||
"qwen1.5-72b-chat": {"prompt": 0.000714, "completion": 0.001428},
|
||||
"qwen1.5-32b-chat": {"prompt": 0.0005, "completion": 0.001},
|
||||
"qwen1.5-14b-chat": {"prompt": 0.000286, "completion": 0.000571},
|
||||
"qwen1.5-7b-chat": {"prompt": 0.000143, "completion": 0.000286},
|
||||
"qwen1.5-1.8b-chat": {"prompt": 0, "completion": 0},
|
||||
"qwen1.5-0.5b-chat": {"prompt": 0, "completion": 0},
|
||||
"qwen-turbo": {"prompt": 0.00028, "completion": 0.00083},
|
||||
"qwen-long": {"prompt": 0.00007, "completion": 0.00028},
|
||||
"qwen-plus": {"prompt": 0.00055, "completion": 0.00166},
|
||||
"qwen-max": {"prompt": 0.0055, "completion": 0.0166},
|
||||
"qwen-max-0428": {"prompt": 0.0055, "completion": 0.0166},
|
||||
"qwen-max-0403": {"prompt": 0.0055, "completion": 0.0166},
|
||||
"qwen-max-0107": {"prompt": 0.0055, "completion": 0.0166},
|
||||
"qwen-max-1201": {"prompt": 0.0166, "completion": 0.0166},
|
||||
"qwen-max-longcontext": {"prompt": 0.0055, "completion": 0.0166},
|
||||
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-72b-chat": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
|
|
@ -147,9 +190,13 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
|
|||
|
||||
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
|
||||
TOKEN_MAX = {
|
||||
"gpt-4o-2024-05-13": 128000,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4-turbo-2024-04-09": 128000,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4-1106-vision-preview": 128000,
|
||||
"gpt-4": 8192,
|
||||
|
|
@ -157,7 +204,6 @@ TOKEN_MAX = {
|
|||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0613": 32768,
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-3.5-turbo-0125": 16385,
|
||||
"gpt-3.5-turbo": 16385,
|
||||
"gpt-3.5-turbo-1106": 16385,
|
||||
|
|
@ -182,17 +228,108 @@ TOKEN_MAX = {
|
|||
"claude-2.1": 200000,
|
||||
"claude-3-sonnet-20240229": 200000,
|
||||
"claude-3-opus-20240229": 200000,
|
||||
"claude-3-5-sonnet-20240620": 200000,
|
||||
"claude-3-haiku-20240307": 200000,
|
||||
"yi-34b-chat-0205": 4000,
|
||||
"yi-34b-chat-200k": 200000,
|
||||
"yi-large": 16385,
|
||||
"microsoft/wizardlm-2-8x22b": 65536,
|
||||
"meta-llama/llama-3-70b-instruct": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
"openai/gpt-3.5-turbo-0125": 16385,
|
||||
"openai/gpt-4-turbo-preview": 128000,
|
||||
"deepseek-chat": 32768,
|
||||
"deepseek-coder": 16385,
|
||||
"doubao-lite-4k-240515": 4000,
|
||||
"doubao-lite-32k-240515": 32000,
|
||||
"doubao-lite-128k-240515": 128000,
|
||||
"doubao-pro-4k-240515": 4000,
|
||||
"doubao-pro-32k-240515": 32000,
|
||||
"doubao-pro-128k-240515": 128000,
|
||||
# Qwen https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-72b-api-detailes?spm=a2c4g.11186623.0.i20
|
||||
"qwen2-57b-a14b-instruct": 32768,
|
||||
"qwen2-72b-instruct": 131072,
|
||||
"qwen2-7b-instruct": 32768,
|
||||
"qwen2-1.5b-instruct": 32768,
|
||||
"qwen2-0.5b-instruct": 32768,
|
||||
"qwen1.5-110b-chat": 32000,
|
||||
"qwen1.5-72b-chat": 32000,
|
||||
"qwen1.5-32b-chat": 32000,
|
||||
"qwen1.5-14b-chat": 8000,
|
||||
"qwen1.5-7b-chat": 32000,
|
||||
"qwen1.5-1.8b-chat": 32000,
|
||||
"qwen1.5-0.5b-chat": 32000,
|
||||
"codeqwen1.5-7b-chat": 64000,
|
||||
"qwen-72b-chat": 32000,
|
||||
"qwen-14b-chat": 8000,
|
||||
"qwen-7b-chat": 32000,
|
||||
"qwen-1.8b-longcontext-chat": 32000,
|
||||
"qwen-1.8b-chat": 8000,
|
||||
}
|
||||
|
||||
# For Amazon Bedrock US region
|
||||
# See https://aws.amazon.com/cn/bedrock/pricing/
|
||||
|
||||
BEDROCK_TOKEN_COSTS = {
|
||||
"amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004},
|
||||
"amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004},
|
||||
"anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024},
|
||||
"anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024},
|
||||
"anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
"anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075},
|
||||
"cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015},
|
||||
"cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015},
|
||||
"cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001},
|
||||
"meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001},
|
||||
"meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006},
|
||||
"meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035},
|
||||
"mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002},
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007},
|
||||
"mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024},
|
||||
"ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188},
|
||||
"ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188},
|
||||
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
|
||||
}
|
||||
|
||||
# https://xinghuo.xfyun.cn/sparkapi?scr=price
|
||||
SPARK_TOKENS = {
|
||||
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
|
||||
"generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
|
||||
"generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
|
||||
"generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
|
||||
}
|
||||
|
||||
|
||||
def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
logger.info(f"Warning: model {model} not found in tiktoken. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model in {
|
||||
"gpt-3.5-turbo-0613",
|
||||
|
|
@ -206,12 +343,14 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4-1106-vision-preview",
|
||||
"gpt-4o",
|
||||
"gpt-4o-2024-05-13",
|
||||
"gpt-4o-mini",
|
||||
}:
|
||||
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
|
||||
|
|
@ -220,11 +359,11 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" == model:
|
||||
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
|
||||
return count_message_tokens(messages, model="gpt-3.5-turbo-0125")
|
||||
logger.info("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
|
||||
return count_input_tokens(messages, model="gpt-3.5-turbo-0125")
|
||||
elif "gpt-4" == model:
|
||||
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
logger.info("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_input_tokens(messages, model="gpt-4-0613")
|
||||
elif "open-llm-model" == model:
|
||||
"""
|
||||
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
|
||||
|
|
@ -255,21 +394,21 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
return num_tokens
|
||||
|
||||
|
||||
def count_string_tokens(string: str, model_name: str) -> int:
|
||||
def count_output_tokens(string: str, model: str) -> int:
|
||||
"""
|
||||
Returns the number of tokens in a text string.
|
||||
|
||||
Args:
|
||||
string (str): The text string.
|
||||
model_name (str): The name of the encoding to use. (e.g., "gpt-3.5-turbo")
|
||||
model (str): The name of the encoding to use. (e.g., "gpt-3.5-turbo")
|
||||
|
||||
Returns:
|
||||
int: The number of tokens in the text string.
|
||||
"""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
logger.info(f"Warning: model {model} not found in tiktoken. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
|
@ -286,4 +425,16 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) ->
|
|||
"""
|
||||
if model not in TOKEN_MAX:
|
||||
return default
|
||||
return TOKEN_MAX[model] - count_message_tokens(messages) - 1
|
||||
return TOKEN_MAX[model] - count_input_tokens(messages) - 1
|
||||
|
||||
|
||||
async def get_openrouter_tokens(chunk: ChatCompletionChunk) -> CompletionUsage:
|
||||
"""refs to https://openrouter.ai/docs#querying-cost-and-stats"""
|
||||
url = f"https://openrouter.ai/api/v1/generation?id={chunk.id}"
|
||||
resp = await apost(url=url, as_json=True)
|
||||
tokens_prompt = resp.get("tokens_prompt", 0)
|
||||
completion_tokens = resp.get("tokens_completion", 0)
|
||||
usage = CompletionUsage(
|
||||
prompt_tokens=tokens_prompt, completion_tokens=completion_tokens, total_tokens=tokens_prompt + completion_tokens
|
||||
)
|
||||
return usage
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue