Merge branch 'geekan:main' into update_qwen

This commit is contained in:
usamimeri_renko 2024-05-29 16:27:49 +08:00 committed by GitHub
commit 4e3c46a5fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 389 additions and 294 deletions

View file

@@ -35,7 +35,7 @@ browser:
mermaid:
engine: "pyppeteer"
path: "/Applications/Google Chrome.app"
pyppeteer_path: "/Applications/Google Chrome.app"
redis:
host: "YOUR_HOST"

View file

@@ -0,0 +1,5 @@
llm:
api_type: "ark"
model: "" # your model endpoint like ep-xxx
base_url: "https://ark.cn-beijing.volces.com/api/v3"
api_key: "" # your api-key like ey……

View file

@@ -0,0 +1,10 @@
# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
llm:
api_type: "spark"
# 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"
app_id: ""
api_key: ""
api_secret: ""
domain: "general" # 取值为 [general,generalv2,generalv3,generalv3.5] 和url一一对应

View file

@@ -96,7 +96,7 @@ ### Open-Ended Tasks Dataset and Requirements
| 14 | 14_image_background_removal | Image Background Removal | Remove the background of a given image | This is an image, you need to use python toolkit rembg remove the background of the image. image path:'{data_dir}/open_ended_tasks/14_image_background_removal.jpg'; save path:'{data_dir}/open_ended_tasks/14_image_background_removal.jpg' |
| 15 | 15_text2img | Text2Img | Use SD tools to generate images | I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url = "http://your.sd.service.ip:port" |
| 16 | 16_image_2_code_generation | Image2Code Generation | Web code generation | This is a image. First, convert the image to webpage code including HTML, CSS and JS in one go, and finally save webpage code in a file.The image path: '{data_dir}/open_ended_tasks/16_image_2_code_generation.png'. NOTE: All required dependencies and environments have been fully installed and configured. |
| 17 | 17_image_2_code_generation | Image2Code Generation | Web code generation | This is a image. First, convert the image to webpage code including HTML, CSS and JS in one go, and finally save webpage code in a file.The image path: '{data_dir}/open_ended_tasks/16_image_2_code_generation.png'. NOTE: All required dependencies and environments have been fully installed and configured. |
| 17 | 17_image_2_code_generation | Image2Code Generation | Web code generation | This is a image. First, convert the image to webpage code including HTML, CSS and JS in one go, and finally save webpage code in a file.The image path: '{data_dir}/open_ended_tasks/17_image_2_code_generation.png'. NOTE: All required dependencies and environments have been fully installed and configured. |
| 18 | 18_generate_games | Generate games using existing repo | Game tool usage (pyxel) | Create a Snake game. Players need to control the movement of the snake to eat food and grow its body, while avoiding the snake's head touching their own body or game boundaries. Games need to have basic game logic, user interface. During the production process, please consider factors such as playability, beautiful interface, and convenient operation of the game. Note: pyxel environment already satisfied |
| 19 | 19_generate_games | Generate games using existing repo | Game tool usage (pyxel) | You are a professional game developer, please use pyxel software to create a simple jumping game. The game needs to include a character that can move left and right on the screen. When the player presses the spacebar, the character should jump. Please ensure that the game is easy to operate, with clear graphics, and complies with the functional limitations of pyxel software. Note: pyxel environment already satisfied |
| 20 | 20_generate_games | Generate games using existing repo | Game tool usage (pyxel) | Make a mouse click game that click button as many times as possible in 30 seconds using pyxel. Note: pyxel environment already satisfied |

View file

@@ -84,7 +84,7 @@ class WriteTasks(Action):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python packages", set()))
packages = set(m.get("Required packages", set()))
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")

View file

@@ -9,10 +9,10 @@ from typing import List
from metagpt.actions.action_node import ActionNode
REQUIRED_PYTHON_PACKAGES = ActionNode(
key="Required Python packages",
REQUIRED_PACKAGES = ActionNode(
key="Required packages",
expected_type=List[str],
instruction="Provide required Python packages in requirements.txt format.",
instruction="Provide required packages in requirements.txt format.",
example=["flask==1.1.2", "bcrypt==3.2.0"],
)
@@ -97,7 +97,7 @@ ANYTHING_UNCLEAR_PM = ActionNode(
)
NODES = [
REQUIRED_PYTHON_PACKAGES,
REQUIRED_PACKAGES,
REQUIRED_OTHER_LANGUAGE_PACKAGES,
LOGIC_ANALYSIS,
TASK_LIST,
@@ -107,7 +107,7 @@ NODES = [
]
REFINED_NODES = [
REQUIRED_PYTHON_PACKAGES,
REQUIRED_PACKAGES,
REQUIRED_OTHER_LANGUAGE_PACKAGES,
REFINED_LOGIC_ANALYSIS,
REFINED_TASK_LIST,

View file

@@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
end", "Anything UNCLEAR": "目前项目要求明确没有不清楚的地方"}
## Tasks
{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
## Code Files
----- index.html

View file

@@ -33,6 +33,7 @@ class LLMType(Enum):
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
BEDROCK = "bedrock"
ARK = "ark"
def __missing__(self, key):
return self.OPENAI

View file

@@ -56,10 +56,16 @@ def get_role_environment(sim_code: str, role_name: str, step: int = 0) -> dict:
def write_curr_sim_code(curr_sim_code: dict, temp_storage_path: Optional[Path] = None):
temp_storage_path = Path(temp_storage_path) or TEMP_STORAGE_PATH
if temp_storage_path is None:
temp_storage_path = TEMP_STORAGE_PATH
else:
temp_storage_path = Path(temp_storage_path)
write_json_file(temp_storage_path.joinpath("curr_sim_code.json"), curr_sim_code)
def write_curr_step(curr_step: dict, temp_storage_path: Optional[Path] = None):
temp_storage_path = Path(temp_storage_path) or TEMP_STORAGE_PATH
if temp_storage_path is None:
temp_storage_path = TEMP_STORAGE_PATH
else:
temp_storage_path = Path(temp_storage_path)
write_json_file(temp_storage_path.joinpath("curr_step.json"), curr_step)

View file

@@ -18,6 +18,7 @@ from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM
from metagpt.provider.ark_api import ArkLLM
__all__ = [
"GeminiLLM",
@@ -32,4 +33,5 @@ __all__ = [
"DashScopeLLM",
"AnthropicLLM",
"BedrockLLM",
"ArkLLM",
]

View file

@@ -0,0 +1,44 @@
from openai import AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from metagpt.configs.llm_config import LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
@register_provider(LLMType.ARK)
class ArkLLM(OpenAILLM):
"""
用于火山方舟的API
https://www.volcengine.com/docs/82379/1263482
"""
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
stream=True,
extra_body={"stream_options": {"include_usage": True}} # 只有增加这个参数才会在流式时最后返回usage
)
usage = None
collected_messages = []
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
if chunk.usage:
# 火山方舟的流式调用会在最后一个chunk中返回usage,最后一个chunk的choices为[]
usage = CompletionUsage(**chunk.usage)
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
self._update_costs(usage, chunk.model)
return full_reply_content
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage, rsp.model)
return rsp

View file

@@ -102,7 +102,10 @@ class OpenAILLM(BaseLLM):
if finish_reason:
if hasattr(chunk, "usage") and chunk.usage is not None:
# Some services have usage as an attribute of the chunk, such as Fireworks
usage = CompletionUsage(**chunk.usage)
if isinstance(chunk.usage, CompletionUsage):
usage = chunk.usage
else:
usage = CompletionUsage(**chunk.usage)
elif hasattr(chunk.choices[0], "usage"):
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
usage = CompletionUsage(**chunk.choices[0].usage)

View file

@@ -1,175 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : spark_api.py
"""
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from sparkai.core.messages import _convert_to_message, convert_to_messages
from sparkai.core.messages.ai import AIMessage
from sparkai.core.messages.base import BaseMessage
from sparkai.core.messages.human import HumanMessage
from sparkai.core.messages.system import SystemMessage
from sparkai.core.outputs.llm_result import LLMResult
from sparkai.llm.llm import ChatSparkLLM
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import any_to_str
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import SPARK_TOKENS
@register_provider(LLMType.SPARK)
class SparkLLM(BaseLLM):
"""
用于讯飞星火大模型系列
参考https://github.com/iflytek/spark-ai-python"""
def __init__(self, config: LLMConfig):
self.config = config
logger.warning("SparkLLM当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
self.model = self.config.domain
self._init_client()
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
def _init_client(self):
self.client = ChatSparkLLM(
spark_api_url=self.config.base_url,
spark_app_id=self.config.app_id,
spark_api_key=self.config.api_key,
spark_api_secret=self.config.api_secret,
spark_llm_domain=self.config.domain,
streaming=True,
)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
def _system_msg(self, msg: str) -> SystemMessage:
return _convert_to_message(msg)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
# 不支持
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
return _convert_to_message(msg)
def _assistant_msg(self, msg: str) -> AIMessage:
return _convert_to_message(msg)
def get_choice_text(self, rsp: LLMResult) -> str:
return rsp.generations[0][0].text
def get_usage(self, response: LLMResult):
message = response.generations[0][0].message
if hasattr(message, "additional_kwargs"):
return message.additional_kwargs.get("token_usage", {})
else:
return {}
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
response = await self.acreate(messages, stream=False)
usage = self.get_usage(response)
self._update_costs(usage)
return response
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
# 不支持异步
w = GetMessageFromWeb(messages, self.config)
return w.run()
return await self._achat_completion(messages, timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
response = await self.acreate(messages, stream=True)
collected_content = []
usage = {}
async for chunk in response:
collected_content.append(chunk.content)
log_llm_stream(chunk.content)
if hasattr(chunk, "additional_kwargs"):
usage = chunk.additional_kwargs.get("token_usage", {})
class GetMessageFromWeb:
class WsParam:
"""
该类适合讯飞星火大部分接口的调用
输入 app_id, api_key, api_secret, spark_url以初始化
create_url方法返回接口url
"""
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
# 初始化
def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
self.app_id = app_id
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
self.spark_url = spark_url
self.message = message
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def __init__(self, text, config: LLMConfig):
self.text = text
self.ret = ""
self.spark_appid = config.app_id
self.spark_api_secret = config.api_secret
self.spark_api_key = config.api_key
self.domain = config.domain
self.spark_url = config.base_url
def on_message(self, ws, message):
data = json.loads(message)
code = data["header"]["code"]
if code != 0:
ws.close() # 请求错误则关闭socket
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
return
async def acreate(self, messages: list[dict], stream: bool = True):
messages = convert_to_messages(messages)
if stream:
return self.client.astream(messages)
else:
choices = data["payload"]["choices"]
# seq = choices["seq"] # 服务端是流式返回seq为返回的数据序号
status = choices["status"] # 服务端是流式返回status用于判断信息是否传送完毕
content = choices["text"][0]["content"] # 本次接收到的回答文本
self.ret += content
if status == 2:
ws.close()
# 收到websocket错误的处理
def on_error(self, ws, error):
# on_message方法处理接收到的信息出现任何错误都会调用这个方法
logger.critical(f"通讯连接出错,【错误提示: {error}")
# 收到websocket关闭的处理
def on_close(self, ws, one, two):
pass
# 处理请求数据
def gen_params(self):
data = {
"header": {"app_id": self.spark_appid, "uid": "1234"},
"parameter": {
"chat": {
# domain为必传参数
"domain": self.domain,
# 以下为可微调,非必传参数
# 注意官方建议temperature和top_k修改一个即可
"max_tokens": 2048, # 默认2048模型回答的tokens的最大长度即允许它输出文本的最长字数
"temperature": 0.5, # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高
"top_k": 4, # 取值为[16],默认为4。从k个候选中随机选择一个非等概率
}
},
"payload": {"message": {"text": self.text}},
}
return data
def send(self, ws, *args):
data = json.dumps(self.gen_params())
ws.send(data)
# 收到websocket连接建立的处理
def on_open(self, ws):
thread.start_new_thread(self.send, (ws,))
# 处理收到的 websocket消息出现任何错误调用on_error方法
def run(self):
return self._run(self.text)
def _run(self, text_list):
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
ws_url = ws_param.create_url()
websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能
ws = websocket.WebSocketApp(
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
return self.ret
return await self.client.agenerate([messages])

View file

@@ -89,9 +89,9 @@ class RAGEmbeddingFactory(GenericFactory):
return OllamaEmbedding(**params)
def _try_set_model_and_batch_size(self, params: dict):
"""Set the model and embed_batch_size only when they are specified."""
"""Set the model_name and embed_batch_size only when they are specified."""
if config.embedding.model:
params["model"] = config.embedding.model
params["model_name"] = config.embedding.model
if config.embedding.embed_batch_size:
params["embed_batch_size"] = config.embedding.embed_batch_size

View file

@@ -136,4 +136,4 @@ class Assistant(Role):
try:
self.memory = BrainMemory(**m)
except Exception as e:
logger.exception(f"load error:{e}, data:{jsn}")
logger.exception(f"load error:{e}, data:{m}")

View file

@@ -53,30 +53,30 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await page.wait_for_load_state("networkidle")
await page.wait_for_selector("div#container", state="attached")
# mermaid_config = {}
mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
#
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
await page.evaluate(
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}""",
[mermaid_code, mermaid_config, my_css, background_color],
)
if "svg" in suffixes:
svg_xml = await page.evaluate(

View file

@@ -55,29 +55,29 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await page.goto(mermaid_html_url)
await page.querySelector("div#container")
# mermaid_config = {}
mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
await page.evaluate(
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}""",
[mermaid_code, mermaid_config, my_css, background_color],
)
if "svg" in suffixes:
svg_xml = await page.evaluate(

View file

@@ -68,6 +68,15 @@ TOKEN_COSTS = {
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
# For ark model https://www.volcengine.com/docs/82379/1099320
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
}
@@ -213,6 +222,12 @@ TOKEN_MAX = {
"openai/gpt-4-turbo-preview": 128000,
"deepseek-chat": 32768,
"deepseek-coder": 16385,
"doubao-lite-4k-240515": 4000,
"doubao-lite-32k-240515": 32000,
"doubao-lite-128k-240515": 128000,
"doubao-pro-4k-240515": 4000,
"doubao-pro-32k-240515": 32000,
"doubao-pro-128k-240515": 128000,
}
# For Amazon Bedrock US region
@@ -262,6 +277,14 @@ BEDROCK_TOKEN_COSTS = {
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
}
# https://xinghuo.xfyun.cn/sparkapi?scr=price
SPARK_TOKENS = {
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
"generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
"generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
"generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
}
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""

View file

@@ -44,7 +44,7 @@ wrapt==1.15.0
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py
#aioboto3~=12.4.0 # Used by metagpt/utils/s3.py
aioredis~=2.0.1 # Used by metagpt/utils/redis.py
websocket-client==1.6.2
websocket-client~=1.8.0
aiofiles==23.2.1
gitpython==3.1.40
zhipuai==2.0.1
@@ -71,3 +71,4 @@ dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
gymnasium==0.29.1
boto3~=1.34.69
spark_ai_python~=0.3.30

File diff suppressed because one or more lines are too long

View file

@@ -38,7 +38,7 @@ DESIGN = {
TASK = {
"Required Python packages": ["pygame==2.0.1"],
"Required packages": ["pygame==2.0.1"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Logic Analysis": [
["game.py", "Contains Game class and related functions for game logic"],

View file

@@ -69,3 +69,5 @@ mock_llm_config_bedrock = LLMConfig(
secret_key="123abc",
max_token=10000,
)
mock_llm_config_ark = LLMConfig(api_type="ark", api_key="eyxxx", base_url="xxx", model="ep-xxx")

View file

@@ -0,0 +1,85 @@
"""
用于火山方舟Python SDK V3的测试用例
API文档https://www.volcengine.com/docs/82379/1263482
"""
from typing import AsyncIterator, List, Union
import pytest
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from metagpt.provider.ark_api import ArkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_ark
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "AI assistant"
resp_cont = resp_cont_tmpl.format(name=name)
USAGE = {"completion_tokens": 1000, "prompt_tokens": 1000, "total_tokens": 2000}
default_resp = get_openai_chat_completion(name)
default_resp.model = "doubao-pro-32k-240515"
default_resp.usage = USAGE
def create_chat_completion_chunk(
content: str, finish_reason: str = None, choices: List[Choice] = None
) -> ChatCompletionChunk:
if choices is None:
choices = [
Choice(
delta=ChoiceDelta(content=content, function_call=None, role="assistant", tool_calls=None),
finish_reason=finish_reason,
index=0,
logprobs=None,
)
]
return ChatCompletionChunk(
id="012",
choices=choices,
created=1716278586,
model="doubao-pro-32k-240515",
object="chat.completion.chunk",
system_fingerprint=None,
usage=None if choices else USAGE,
)
ark_resp_chunk = create_chat_completion_chunk(content="")
ark_resp_chunk_finish = create_chat_completion_chunk(content=resp_cont, finish_reason="stop")
ark_resp_chunk_last = create_chat_completion_chunk(content="", choices=[])
async def chunk_iterator(chunks: List[ChatCompletionChunk]) -> AsyncIterator[ChatCompletionChunk]:
for chunk in chunks:
yield chunk
async def mock_ark_acompletions_create(
self, stream: bool = False, **kwargs
) -> Union[ChatCompletionChunk, ChatCompletion]:
if stream:
chunks = [ark_resp_chunk, ark_resp_chunk_finish, ark_resp_chunk_last]
return chunk_iterator(chunks)
else:
return default_resp
@pytest.mark.asyncio
async def test_ark_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_ark_acompletions_create)
llm = ArkLLM(mock_llm_config_ark)
resp = await llm.acompletion(messages)
assert resp.choices[0].finish_reason == "stop"
assert resp.choices[0].message.content == resp_cont
assert resp.usage == USAGE
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)

View file

@@ -1,62 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
"""
用于讯飞星火SDK的测试用例
文档https://www.xfyun.cn/doc/spark/Web.html
"""
from typing import AsyncIterator, List
import pytest
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
from sparkai.core.outputs.chat_generation import ChatGeneration
from sparkai.core.outputs.llm_result import LLMResult
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from metagpt.provider.spark_api import SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="Spark")
USAGE = {
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
}
spark_agenerate_result = LLMResult(
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
)
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
class MockWebSocketApp(object):
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
pass
def run_forever(self, sslopt=None):
pass
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
for chunk in chunks:
yield chunk
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
assert ret == ""
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
assert resp == resp_cont
async def mock_spark_acreate(self, messages, stream):
if stream:
return chunk_iterator(chunks)
else:
return spark_agenerate_result
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
spark_llm = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config_spark)
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_llm.acompletion([messages])
assert spark_llm.get_choice_text(resp) == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)

View file

@@ -66,7 +66,7 @@ class TestRAGEmbeddingFactory:
@pytest.mark.parametrize(
"model, embed_batch_size, expected_params",
[("test_model", 100, {"model": "test_model", "embed_batch_size": 100}), (None, None, {})],
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
)
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
# Mock

View file

@@ -13,7 +13,7 @@ from metagpt.utils.mermaid import MMC1, mermaid_to_file
@pytest.mark.asyncio
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
@pytest.mark.parametrize("engine", ["nodejs", "ink", "playwright", "pyppeteer"])
async def test_mermaid(engine, context, mermaid_mocker):
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
# ink prerequisites: connected to internet