Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-05-01 20:03:28 +02:00)

Commit 4e3c46a5fd: Merge branch 'geekan:main' into update_qwen

26 changed files with 389 additions and 294 deletions
@@ -84,7 +84,7 @@ class WriteTasks(Action):
     async def _update_requirements(self, doc):
         m = json.loads(doc.content)
-        packages = set(m.get("Required Python packages", set()))
+        packages = set(m.get("Required packages", set()))
         requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
        if not requirement_doc:
             requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
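Note: this key rename only holds together because every consumer switches in the same commit (see the ActionNode definition in the next hunk). `dict.get` fails silently on a key miss, so a stale key would just yield the empty default — a quick illustration with hypothetical values:

    # dict.get falls back silently on a key miss, which is why the consumer
    # above had to change together with the ActionNode key rename below.
    m = {"Required packages": ["flask==1.1.2", "bcrypt==3.2.0"]}
    assert m.get("Required Python packages", set()) == set()  # old key: silently empty
    assert set(m.get("Required packages", set())) == {"flask==1.1.2", "bcrypt==3.2.0"}
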
@@ -9,10 +9,10 @@ from typing import List
 from metagpt.actions.action_node import ActionNode

-REQUIRED_PYTHON_PACKAGES = ActionNode(
-    key="Required Python packages",
+REQUIRED_PACKAGES = ActionNode(
+    key="Required packages",
     expected_type=List[str],
-    instruction="Provide required Python packages in requirements.txt format.",
+    instruction="Provide required packages in requirements.txt format.",
     example=["flask==1.1.2", "bcrypt==3.2.0"],
 )
@@ -97,7 +97,7 @@ ANYTHING_UNCLEAR_PM = ActionNode(
 )

 NODES = [
-    REQUIRED_PYTHON_PACKAGES,
+    REQUIRED_PACKAGES,
     REQUIRED_OTHER_LANGUAGE_PACKAGES,
     LOGIC_ANALYSIS,
     TASK_LIST,
@@ -107,7 +107,7 @@ NODES = [
 ]

 REFINED_NODES = [
-    REQUIRED_PYTHON_PACKAGES,
+    REQUIRED_PACKAGES,
     REQUIRED_OTHER_LANGUAGE_PACKAGES,
     REFINED_LOGIC_ANALYSIS,
     REFINED_TASK_LIST,
@@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
 end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}

 ## Tasks
-{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "'game.js' 包含游戏逻辑相关的函数,被 'main.js' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
+{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "'game.js' 包含游戏逻辑相关的函数,被 'main.js' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}

 ## Code Files
 ----- index.html
@@ -33,6 +33,7 @@ class LLMType(Enum):
     YI = "yi"  # lingyiwanwu
     OPENROUTER = "openrouter"
     BEDROCK = "bedrock"
+    ARK = "ark"

     def __missing__(self, key):
         return self.OPENAI
@@ -56,10 +56,16 @@ def get_role_environment(sim_code: str, role_name: str, step: int = 0) -> dict:


 def write_curr_sim_code(curr_sim_code: dict, temp_storage_path: Optional[Path] = None):
-    temp_storage_path = Path(temp_storage_path) or TEMP_STORAGE_PATH
+    if temp_storage_path is None:
+        temp_storage_path = TEMP_STORAGE_PATH
+    else:
+        temp_storage_path = Path(temp_storage_path)
     write_json_file(temp_storage_path.joinpath("curr_sim_code.json"), curr_sim_code)


 def write_curr_step(curr_step: dict, temp_storage_path: Optional[Path] = None):
-    temp_storage_path = Path(temp_storage_path) or TEMP_STORAGE_PATH
+    if temp_storage_path is None:
+        temp_storage_path = TEMP_STORAGE_PATH
+    else:
+        temp_storage_path = Path(temp_storage_path)
     write_json_file(temp_storage_path.joinpath("curr_step.json"), curr_step)
@@ -18,6 +18,7 @@ from metagpt.provider.qianfan_api import QianFanLLM
 from metagpt.provider.dashscope_api import DashScopeLLM
 from metagpt.provider.anthropic_api import AnthropicLLM
 from metagpt.provider.bedrock_api import BedrockLLM
+from metagpt.provider.ark_api import ArkLLM

 __all__ = [
     "GeminiLLM",
@@ -32,4 +33,5 @@ __all__ = [
     "DashScopeLLM",
     "AnthropicLLM",
     "BedrockLLM",
+    "ArkLLM",
 ]
metagpt/provider/ark_api.py (new file, +44 lines)
@@ -0,0 +1,44 @@
+from openai import AsyncStream
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
+from metagpt.configs.llm_config import LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
+from metagpt.logs import log_llm_stream
+from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.provider.openai_api import OpenAILLM
+
+
+@register_provider(LLMType.ARK)
+class ArkLLM(OpenAILLM):
+    """
+    Provider for the Volcengine Ark API.
+    See: https://www.volcengine.com/docs/82379/1263482
+    """
+
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
+            **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
+            stream=True,
+            extra_body={"stream_options": {"include_usage": True}},  # usage is only returned in the final streamed chunk when this option is set
+        )
+        usage = None
+        collected_messages = []
+        async for chunk in response:
+            chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
+            log_llm_stream(chunk_message)
+            collected_messages.append(chunk_message)
+            if chunk.usage:
+                # Ark streaming returns usage in the last chunk, whose choices list is empty
+                usage = CompletionUsage(**chunk.usage)
+
+        log_llm_stream("\n")
+        full_reply_content = "".join(collected_messages)
+        self._update_costs(usage, chunk.model)
+        return full_reply_content
+
+    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
+        kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
+        self._update_costs(rsp.usage, rsp.model)
+        return rsp
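For reviewers who want to try the new provider, a minimal smoke test — a sketch only, assuming the `LLMConfig` fields used elsewhere in this repo (`api_type`, `api_key`, `base_url`, `model`) and placeholder Ark credentials:

    import asyncio

    from metagpt.configs.llm_config import LLMConfig, LLMType
    from metagpt.provider.ark_api import ArkLLM

    async def main():
        # placeholder credentials and endpoint; substitute real Ark values
        config = LLMConfig(
            api_type=LLMType.ARK,
            api_key="YOUR_ARK_API_KEY",
            base_url="https://ark.cn-beijing.volces.com/api/v3",
            model="doubao-pro-32k-240515",
        )
        llm = ArkLLM(config)
        print(await llm.aask("ping"))  # BaseLLM.aask drives the completion methods above

    asyncio.run(main())
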
@@ -102,7 +102,10 @@ class OpenAILLM(BaseLLM):
             if finish_reason:
                 if hasattr(chunk, "usage") and chunk.usage is not None:
                     # Some services have usage as an attribute of the chunk, such as Fireworks
-                    usage = CompletionUsage(**chunk.usage)
+                    if isinstance(chunk.usage, CompletionUsage):
+                        usage = chunk.usage
+                    else:
+                        usage = CompletionUsage(**chunk.usage)
                 elif hasattr(chunk.choices[0], "usage"):
                     # The usage of some services is an attribute of chunk.choices[0], such as Moonshot
                     usage = CompletionUsage(**chunk.choices[0].usage)
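The extra `isinstance` branch exists because OpenAI-compatible backends disagree on the chunk shape: some paths hand back an already-parsed `CompletionUsage`, others a raw dict. The same guard, factored as a hypothetical helper for illustration:

    from openai.types import CompletionUsage

    def normalize_usage(raw) -> CompletionUsage:
        # accept either a parsed CompletionUsage or a plain dict
        return raw if isinstance(raw, CompletionUsage) else CompletionUsage(**raw)
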
@@ -1,175 +1,95 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
 @File : spark_api.py
 """
-import _thread as thread
-import base64
-import datetime
-import hashlib
-import hmac
-import json
-import ssl
-from time import mktime
-from urllib.parse import urlencode, urlparse
-from wsgiref.handlers import format_date_time
-
-import websocket  # uses websocket_client
+from sparkai.core.messages import _convert_to_message, convert_to_messages
+from sparkai.core.messages.ai import AIMessage
+from sparkai.core.messages.base import BaseMessage
+from sparkai.core.messages.human import HumanMessage
+from sparkai.core.messages.system import SystemMessage
+from sparkai.core.outputs.llm_result import LLMResult
+from sparkai.llm.llm import ChatSparkLLM

 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
-from metagpt.logs import logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.utils.common import any_to_str
+from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.token_counter import SPARK_TOKENS


 @register_provider(LLMType.SPARK)
 class SparkLLM(BaseLLM):
+    """
+    Provider for the iFLYTEK Spark model family.
+    See: https://github.com/iflytek/spark-ai-python"""

     def __init__(self, config: LLMConfig):
         self.config = config
-        logger.warning("SparkLLM: this method cannot run asynchronously; acompletion does not provide parallel access.")
+        self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
+        self.model = self.config.domain
+        self._init_client()

-    def get_choice_text(self, rsp: dict) -> str:
-        return rsp["payload"]["choices"]["text"][-1]["content"]
+    def _init_client(self):
+        self.client = ChatSparkLLM(
+            spark_api_url=self.config.base_url,
+            spark_app_id=self.config.app_id,
+            spark_api_key=self.config.api_key,
+            spark_api_secret=self.config.api_secret,
+            spark_llm_domain=self.config.domain,
+            streaming=True,
+        )

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        pass
+    def _system_msg(self, msg: str) -> SystemMessage:
+        return _convert_to_message(msg)

-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        # not supported
-        # logger.warning("this method cannot run asynchronously; acompletion does not provide parallel access.")
-        w = GetMessageFromWeb(messages, self.config)
-        return w.run()
+    def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
+        return _convert_to_message(msg)
+
+    def _assistant_msg(self, msg: str) -> AIMessage:
+        return _convert_to_message(msg)
+
+    def get_choice_text(self, rsp: LLMResult) -> str:
+        return rsp.generations[0][0].text
+
+    def get_usage(self, response: LLMResult):
+        message = response.generations[0][0].message
+        if hasattr(message, "additional_kwargs"):
+            return message.additional_kwargs.get("token_usage", {})
+        else:
+            return {}

     async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
-        pass
+        response = await self.acreate(messages, stream=False)
+        usage = self.get_usage(response)
+        self._update_costs(usage)
+        return response

     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
-        # async is not supported
-        w = GetMessageFromWeb(messages, self.config)
-        return w.run()
+        return await self._achat_completion(messages, timeout)
+
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+        response = await self.acreate(messages, stream=True)
+        collected_content = []
+        usage = {}
+        async for chunk in response:
+            collected_content.append(chunk.content)
+            log_llm_stream(chunk.content)
+            if hasattr(chunk, "additional_kwargs"):
+                usage = chunk.additional_kwargs.get("token_usage", {})
+
+        log_llm_stream("\n")
+        self._update_costs(usage)
+        full_content = "".join(collected_content)
+        return full_content
+
+    def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
+        return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
+
+    async def acreate(self, messages: list[dict], stream: bool = True):
+        messages = convert_to_messages(messages)
+        if stream:
+            return self.client.astream(messages)
+        else:
+            return await self.client.agenerate([messages])

-
-class GetMessageFromWeb:
-    class WsParam:
-        """
-        This class works for most iFLYTEK Spark endpoints.
-        Initialize with app_id, api_key, api_secret and spark_url;
-        create_url returns the endpoint URL.
-        """
-
-        # initialization
-        def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
-            self.app_id = app_id
-            self.api_key = api_key
-            self.api_secret = api_secret
-            self.host = urlparse(spark_url).netloc
-            self.path = urlparse(spark_url).path
-            self.spark_url = spark_url
-            self.message = message
-
-        # build the URL
-        def create_url(self):
-            # RFC 1123 timestamp
-            now = datetime.datetime.now()
-            date = format_date_time(mktime(now.timetuple()))
-
-            # assemble the signature origin string
-            signature_origin = "host: " + self.host + "\n"
-            signature_origin += "date: " + date + "\n"
-            signature_origin += "GET " + self.path + " HTTP/1.1"
-
-            # sign with HMAC-SHA256
-            signature_sha = hmac.new(
-                self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
-            ).digest()
-
-            signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
-
-            authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
-            authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
-
-            # collect the auth params into a dict
-            v = {"authorization": authorization, "date": date, "host": self.host}
-            # append the auth params to build the URL
-            url = self.spark_url + "?" + urlencode(v)
-            # when debugging, print the URL here and compare it against the official demo's output
-            return url
-
-    def __init__(self, text, config: LLMConfig):
-        self.text = text
-        self.ret = ""
-        self.spark_appid = config.app_id
-        self.spark_api_secret = config.api_secret
-        self.spark_api_key = config.api_key
-        self.domain = config.domain
-        self.spark_url = config.base_url
-
-    def on_message(self, ws, message):
-        data = json.loads(message)
-        code = data["header"]["code"]
-
-        if code != 0:
-            ws.close()  # close the socket on a request error
-            logger.critical(f"Failed to fetch the answer; deserialized response: {data}")
-            return
-        else:
-            choices = data["payload"]["choices"]
-            # seq = choices["seq"]  # the server streams results; seq is the frame's sequence number
-            status = choices["status"]  # marks whether the stream has finished
-            content = choices["text"][0]["content"]  # the answer text in this frame
-            self.ret += content
-            if status == 2:
-                ws.close()
-
-    # handle websocket errors
-    def on_error(self, ws, error):
-        # called whenever on_message hits an error
-        logger.critical(f"Connection error: {error}")
-
-    # handle websocket close
-    def on_close(self, ws, one, two):
-        pass
-
-    # build the request payload
-    def gen_params(self):
-        data = {
-            "header": {"app_id": self.spark_appid, "uid": "1234"},
-            "parameter": {
-                "chat": {
-                    # domain is required
-                    "domain": self.domain,
-                    # optional tuning parameters below
-                    # note: officially, adjust only one of temperature and top_k
-                    "max_tokens": 2048,  # default 2048; maximum tokens in the model's reply
-                    "temperature": 0.5,  # in [0,1], default 0.5; higher means more random, divergent answers
-                    "top_k": 4,  # in [1,6], default 4; sample one of k candidates (not uniformly)
-                }
-            },
-            "payload": {"message": {"text": self.text}},
-        }
-        return data
-
-    def send(self, ws, *args):
-        data = json.dumps(self.gen_params())
-        ws.send(data)
-
-    # handle websocket open
-    def on_open(self, ws):
-        thread.start_new_thread(self.send, (ws,))
-
-    # process incoming websocket messages; on any error, on_error is called
-    def run(self):
-        return self._run(self.text)
-
-    def _run(self, text_list):
-        ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
-        ws_url = ws_param.create_url()
-
-        websocket.enableTrace(False)  # tracing disabled by default
-        ws = websocket.WebSocketApp(
-            ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
-        )
-        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-        return self.ret
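This rewrite swaps the hand-rolled WebSocket client (`GetMessageFromWeb`) for the official sparkai SDK, which gives real async streaming plus cost tracking via `SPARK_TOKENS`. A round-trip sketch, assuming the Spark-specific `LLMConfig` fields (`app_id`, `api_secret`, `domain`) and placeholder iFLYTEK credentials:

    import asyncio

    from metagpt.configs.llm_config import LLMConfig, LLMType
    from metagpt.provider.spark_api import SparkLLM

    async def main():
        # placeholder credentials; substitute real iFLYTEK values
        config = LLMConfig(
            api_type=LLMType.SPARK,
            app_id="YOUR_APP_ID",
            api_key="YOUR_API_KEY",
            api_secret="YOUR_API_SECRET",
            domain="generalv3.5",
            base_url="wss://spark-api.xf-yun.com/v3.5/chat",
        )
        llm = SparkLLM(config)
        print(await llm.acompletion_text([{"role": "user", "content": "hi"}], stream=True))

    asyncio.run(main())
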
@@ -89,9 +89,9 @@ class RAGEmbeddingFactory(GenericFactory):
         return OllamaEmbedding(**params)

     def _try_set_model_and_batch_size(self, params: dict):
-        """Set the model and embed_batch_size only when they are specified."""
+        """Set the model_name and embed_batch_size only when they are specified."""
         if config.embedding.model:
-            params["model"] = config.embedding.model
+            params["model_name"] = config.embedding.model

         if config.embedding.embed_batch_size:
             params["embed_batch_size"] = config.embedding.embed_batch_size
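The `model` to `model_name` rename matters because the llama-index embedding classes take `model_name` (a `BaseEmbedding` field), so the old kwarg never selected the model. Illustration only, with an assumed local Ollama model:

    from llama_index.embeddings.ollama import OllamaEmbedding

    params = {"model_name": "nomic-embed-text", "embed_batch_size": 32}
    embedding = OllamaEmbedding(base_url="http://localhost:11434", **params)
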
@@ -136,4 +136,4 @@ class Assistant(Role):
         try:
             self.memory = BrainMemory(**m)
         except Exception as e:
-            logger.exception(f"load error:{e}, data:{jsn}")
+            logger.exception(f"load error:{e}, data:{m}")
@@ -53,30 +53,30 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
         await page.wait_for_load_state("networkidle")

         await page.wait_for_selector("div#container", state="attached")
-        # mermaid_config = {}
+        mermaid_config = {}
         background_color = "#ffffff"
-        # my_css = ""
+        my_css = ""
         await page.evaluate(f'document.body.style.background = "{background_color}";')

-        # metadata = await page.evaluate(
-        #     """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
-        #         const { mermaid, zenuml } = globalThis;
-        #         await mermaid.registerExternalDiagrams([zenuml]);
-        #         mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
-        #         const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
-        #         document.getElementById('container').innerHTML = svg;
-        #         const svgElement = document.querySelector('svg');
-        #         svgElement.style.backgroundColor = backgroundColor;
-        #
-        #         if (myCSS) {
-        #             const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
-        #             style.appendChild(document.createTextNode(myCSS));
-        #             svgElement.appendChild(style);
-        #         }
-        #
-        #     }""",
-        #     [mermaid_code, mermaid_config, my_css, background_color],
-        # )
+        await page.evaluate(
+            """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
+                const { mermaid, zenuml } = globalThis;
+                await mermaid.registerExternalDiagrams([zenuml]);
+                mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
+                const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
+                document.getElementById('container').innerHTML = svg;
+                const svgElement = document.querySelector('svg');
+                svgElement.style.backgroundColor = backgroundColor;
+
+                if (myCSS) {
+                    const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
+                    style.appendChild(document.createTextNode(myCSS));
+                    svgElement.appendChild(style);
+                }
+
+            }""",
+            [mermaid_code, mermaid_config, my_css, background_color],
+        )

         if "svg" in suffixes:
             svg_xml = await page.evaluate(
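This hunk and the pyppeteer variant below make the same change: the previously commented-out in-page render now actually runs, so the SVG exists before export. The core pattern, reduced to a minimal sketch (assumes a Playwright `page` that has already loaded mermaid.min.js, mermaid v10 API):

    # minimal sketch of the render-then-read pattern used above
    svg = await page.evaluate(
        """async ([definition]) => {
            const { mermaid } = globalThis;
            mermaid.initialize({ startOnLoad: false });
            const { svg } = await mermaid.render('my-svg', definition);
            return svg;
        }""",
        [mermaid_code],
    )
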
@@ -55,29 +55,29 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
         await page.goto(mermaid_html_url)

         await page.querySelector("div#container")
-        # mermaid_config = {}
+        mermaid_config = {}
         background_color = "#ffffff"
-        # my_css = ""
+        my_css = ""
         await page.evaluate(f'document.body.style.background = "{background_color}";')

-        # metadata = await page.evaluate(
-        #     """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
-        #         const { mermaid, zenuml } = globalThis;
-        #         await mermaid.registerExternalDiagrams([zenuml]);
-        #         mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
-        #         const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
-        #         document.getElementById('container').innerHTML = svg;
-        #         const svgElement = document.querySelector('svg');
-        #         svgElement.style.backgroundColor = backgroundColor;
-        #
-        #         if (myCSS) {
-        #             const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
-        #             style.appendChild(document.createTextNode(myCSS));
-        #             svgElement.appendChild(style);
-        #         }
-        #     }""",
-        #     [mermaid_code, mermaid_config, my_css, background_color],
-        # )
+        await page.evaluate(
+            """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
+                const { mermaid, zenuml } = globalThis;
+                await mermaid.registerExternalDiagrams([zenuml]);
+                mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
+                const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
+                document.getElementById('container').innerHTML = svg;
+                const svgElement = document.querySelector('svg');
+                svgElement.style.backgroundColor = backgroundColor;
+
+                if (myCSS) {
+                    const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
+                    style.appendChild(document.createTextNode(myCSS));
+                    svgElement.appendChild(style);
+                }
+            }""",
+            [mermaid_code, mermaid_config, my_css, background_color],
+        )

         if "svg" in suffixes:
             svg_xml = await page.evaluate(
@@ -68,6 +68,15 @@ TOKEN_COSTS = {
     "openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
     "deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
     "deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
+    # For ark models: https://www.volcengine.com/docs/82379/1099320
+    "doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
+    "doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
+    "doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
+    "doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
+    "doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
+    "doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
+    "llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
+    "llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
 }

@@ -213,6 +222,12 @@ TOKEN_MAX = {
     "openai/gpt-4-turbo-preview": 128000,
     "deepseek-chat": 32768,
     "deepseek-coder": 16385,
+    "doubao-lite-4k-240515": 4000,
+    "doubao-lite-32k-240515": 32000,
+    "doubao-lite-128k-240515": 128000,
+    "doubao-pro-4k-240515": 4000,
+    "doubao-pro-32k-240515": 32000,
+    "doubao-pro-128k-240515": 128000,
 }

 # For Amazon Bedrock US region
@@ -262,6 +277,14 @@ BEDROCK_TOKEN_COSTS = {
     "ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
 }

+# https://xinghuo.xfyun.cn/sparkapi?scr=price
+SPARK_TOKENS = {
+    "general": {"prompt": 0.0, "completion": 0.0},  # Spark-Lite
+    "generalv2": {"prompt": 0.0188, "completion": 0.0188},  # Spark V2.0
+    "generalv3": {"prompt": 0.0035, "completion": 0.0035},  # Spark Pro
+    "generalv3.5": {"prompt": 0.0035, "completion": 0.0035},  # Spark3.5 Max
+}
+

 def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
     """Return the number of tokens used by a list of messages."""