mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-29 10:56:22 +02:00
Merge branch 'geekan:main' into update_qwen
This commit is contained in:
commit
4e3c46a5fd
26 changed files with 389 additions and 294 deletions
|
|
@ -35,7 +35,7 @@ browser:
|
|||
|
||||
mermaid:
|
||||
engine: "pyppeteer"
|
||||
path: "/Applications/Google Chrome.app"
|
||||
pyppeteer_path: "/Applications/Google Chrome.app"
|
||||
|
||||
redis:
|
||||
host: "YOUR_HOST"
|
||||
|
|
|
|||
5
config/examples/huoshan_ark.yaml
Normal file
5
config/examples/huoshan_ark.yaml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
llm:
|
||||
api_type: "ark"
|
||||
model: "" # your model endpoint like ep-xxx
|
||||
base_url: "https://ark.cn-beijing.volces.com/api/v3"
|
||||
api_key: "" # your api-key like ey……
|
||||
10
config/examples/spark_lite.yaml
Normal file
10
config/examples/spark_lite.yaml
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
|
||||
|
||||
llm:
|
||||
api_type: "spark"
|
||||
# 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
|
||||
base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"
|
||||
app_id: ""
|
||||
api_key: ""
|
||||
api_secret: ""
|
||||
domain: "general" # 取值为 [general,generalv2,generalv3,generalv3.5] 和url一一对应
|
||||
|
|
@ -96,7 +96,7 @@ ### Open-Ended Tasks Dataset and Requirements
|
|||
| 14 | 14_image_background_removal | Image Background Removal | Remove the background of a given image | This is an image, you need to use python toolkit rembg remove the background of the image. image path:'{data_dir}/open_ended_tasks/14_image_background_removal.jpg'; save path:'{data_dir}/open_ended_tasks/14_image_background_removal.jpg' |
|
||||
| 15 | 15_text2img | Text2Img | Use SD tools to generate images | I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url = "http://your.sd.service.ip:port" |
|
||||
| 16 | 16_image_2_code_generation | Image2Code Generation | Web code generation | This is a image. First, convert the image to webpage code including HTML, CSS and JS in one go, and finally save webpage code in a file.The image path: '{data_dir}/open_ended_tasks/16_image_2_code_generation.png'. NOTE: All required dependencies and environments have been fully installed and configured. |
|
||||
| 17 | 17_image_2_code_generation | Image2Code Generation | Web code generation | This is a image. First, convert the image to webpage code including HTML, CSS and JS in one go, and finally save webpage code in a file.The image path: '{data_dir}/open_ended_tasks/16_image_2_code_generation.png'. NOTE: All required dependencies and environments have been fully installed and configured. |
|
||||
| 17 | 17_image_2_code_generation | Image2Code Generation | Web code generation | This is a image. First, convert the image to webpage code including HTML, CSS and JS in one go, and finally save webpage code in a file.The image path: '{data_dir}/open_ended_tasks/17_image_2_code_generation.png'. NOTE: All required dependencies and environments have been fully installed and configured. |
|
||||
| 18 | 18_generate_games | Generate games using existing repo | Game tool usage (pyxel) | Create a Snake game. Players need to control the movement of the snake to eat food and grow its body, while avoiding the snake's head touching their own body or game boundaries. Games need to have basic game logic, user interface. During the production process, please consider factors such as playability, beautiful interface, and convenient operation of the game. Note: pyxel environment already satisfied |
|
||||
| 19 | 19_generate_games | Generate games using existing repo | Game tool usage (pyxel) | You are a professional game developer, please use pyxel software to create a simple jumping game. The game needs to include a character that can move left and right on the screen. When the player presses the spacebar, the character should jump. Please ensure that the game is easy to operate, with clear graphics, and complies with the functional limitations of pyxel software. Note: pyxel environment already satisfied |
|
||||
| 20 | 20_generate_games | Generate games using existing repo | Game tool usage (pyxel) | Make a mouse click game that click button as many times as possible in 30 seconds using pyxel. Note: pyxel environment already satisfied |
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class WriteTasks(Action):
|
|||
|
||||
async def _update_requirements(self, doc):
|
||||
m = json.loads(doc.content)
|
||||
packages = set(m.get("Required Python packages", set()))
|
||||
packages = set(m.get("Required packages", set()))
|
||||
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
if not requirement_doc:
|
||||
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from typing import List
|
|||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
REQUIRED_PYTHON_PACKAGES = ActionNode(
|
||||
key="Required Python packages",
|
||||
REQUIRED_PACKAGES = ActionNode(
|
||||
key="Required packages",
|
||||
expected_type=List[str],
|
||||
instruction="Provide required Python packages in requirements.txt format.",
|
||||
instruction="Provide required packages in requirements.txt format.",
|
||||
example=["flask==1.1.2", "bcrypt==3.2.0"],
|
||||
)
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ ANYTHING_UNCLEAR_PM = ActionNode(
|
|||
)
|
||||
|
||||
NODES = [
|
||||
REQUIRED_PYTHON_PACKAGES,
|
||||
REQUIRED_PACKAGES,
|
||||
REQUIRED_OTHER_LANGUAGE_PACKAGES,
|
||||
LOGIC_ANALYSIS,
|
||||
TASK_LIST,
|
||||
|
|
@ -107,7 +107,7 @@ NODES = [
|
|||
]
|
||||
|
||||
REFINED_NODES = [
|
||||
REQUIRED_PYTHON_PACKAGES,
|
||||
REQUIRED_PACKAGES,
|
||||
REQUIRED_OTHER_LANGUAGE_PACKAGES,
|
||||
REFINED_LOGIC_ANALYSIS,
|
||||
REFINED_TASK_LIST,
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
|
|||
end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
|
||||
|
||||
## Tasks
|
||||
{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
|
||||
{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
|
||||
|
||||
## Code Files
|
||||
----- index.html
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class LLMType(Enum):
|
|||
YI = "yi" # lingyiwanwu
|
||||
OPENROUTER = "openrouter"
|
||||
BEDROCK = "bedrock"
|
||||
ARK = "ark"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
|
|||
|
|
@ -56,10 +56,16 @@ def get_role_environment(sim_code: str, role_name: str, step: int = 0) -> dict:
|
|||
|
||||
|
||||
def write_curr_sim_code(curr_sim_code: dict, temp_storage_path: Optional[Path] = None):
|
||||
temp_storage_path = Path(temp_storage_path) or TEMP_STORAGE_PATH
|
||||
if temp_storage_path is None:
|
||||
temp_storage_path = TEMP_STORAGE_PATH
|
||||
else:
|
||||
temp_storage_path = Path(temp_storage_path)
|
||||
write_json_file(temp_storage_path.joinpath("curr_sim_code.json"), curr_sim_code)
|
||||
|
||||
|
||||
def write_curr_step(curr_step: dict, temp_storage_path: Optional[Path] = None):
|
||||
temp_storage_path = Path(temp_storage_path) or TEMP_STORAGE_PATH
|
||||
if temp_storage_path is None:
|
||||
temp_storage_path = TEMP_STORAGE_PATH
|
||||
else:
|
||||
temp_storage_path = Path(temp_storage_path)
|
||||
write_json_file(temp_storage_path.joinpath("curr_step.json"), curr_step)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from metagpt.provider.qianfan_api import QianFanLLM
|
|||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from metagpt.provider.ark_api import ArkLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
@ -32,4 +33,5 @@ __all__ = [
|
|||
"DashScopeLLM",
|
||||
"AnthropicLLM",
|
||||
"BedrockLLM",
|
||||
"ArkLLM",
|
||||
]
|
||||
|
|
|
|||
44
metagpt/provider/ark_api.py
Normal file
44
metagpt/provider/ark_api.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from openai import AsyncStream
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
|
||||
|
||||
@register_provider(LLMType.ARK)
|
||||
class ArkLLM(OpenAILLM):
|
||||
"""
|
||||
用于火山方舟的API
|
||||
见:https://www.volcengine.com/docs/82379/1263482
|
||||
"""
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
|
||||
stream=True,
|
||||
extra_body={"stream_options": {"include_usage": True}} # 只有增加这个参数才会在流式时最后返回usage
|
||||
)
|
||||
usage = None
|
||||
collected_messages = []
|
||||
async for chunk in response:
|
||||
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
|
||||
log_llm_stream(chunk_message)
|
||||
collected_messages.append(chunk_message)
|
||||
if chunk.usage:
|
||||
# 火山方舟的流式调用会在最后一个chunk中返回usage,最后一个chunk的choices为[]
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
|
||||
log_llm_stream("\n")
|
||||
full_reply_content = "".join(collected_messages)
|
||||
self._update_costs(usage, chunk.model)
|
||||
return full_reply_content
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
|
||||
kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage, rsp.model)
|
||||
return rsp
|
||||
|
|
@ -102,7 +102,10 @@ class OpenAILLM(BaseLLM):
|
|||
if finish_reason:
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
# Some services have usage as an attribute of the chunk, such as Fireworks
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
if isinstance(chunk.usage, CompletionUsage):
|
||||
usage = chunk.usage
|
||||
else:
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
elif hasattr(chunk.choices[0], "usage"):
|
||||
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
|
||||
usage = CompletionUsage(**chunk.choices[0].usage)
|
||||
|
|
|
|||
|
|
@ -1,175 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File : spark_api.py
|
||||
"""
|
||||
import _thread as thread
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import ssl
|
||||
from time import mktime
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import websocket # 使用websocket_client
|
||||
from sparkai.core.messages import _convert_to_message, convert_to_messages
|
||||
from sparkai.core.messages.ai import AIMessage
|
||||
from sparkai.core.messages.base import BaseMessage
|
||||
from sparkai.core.messages.human import HumanMessage
|
||||
from sparkai.core.messages.system import SystemMessage
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
from sparkai.llm.llm import ChatSparkLLM
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import SPARK_TOKENS
|
||||
|
||||
|
||||
@register_provider(LLMType.SPARK)
|
||||
class SparkLLM(BaseLLM):
|
||||
"""
|
||||
用于讯飞星火大模型系列
|
||||
参考:https://github.com/iflytek/spark-ai-python"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
logger.warning("SparkLLM:当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
|
||||
self.model = self.config.domain
|
||||
self._init_client()
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
def _init_client(self):
|
||||
self.client = ChatSparkLLM(
|
||||
spark_api_url=self.config.base_url,
|
||||
spark_app_id=self.config.app_id,
|
||||
spark_api_key=self.config.api_key,
|
||||
spark_api_secret=self.config.api_secret,
|
||||
spark_llm_domain=self.config.domain,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
pass
|
||||
def _system_msg(self, msg: str) -> SystemMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
# 不支持
|
||||
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def _assistant_msg(self, msg: str) -> AIMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def get_choice_text(self, rsp: LLMResult) -> str:
|
||||
return rsp.generations[0][0].text
|
||||
|
||||
def get_usage(self, response: LLMResult):
|
||||
message = response.generations[0][0].message
|
||||
if hasattr(message, "additional_kwargs"):
|
||||
return message.additional_kwargs.get("token_usage", {})
|
||||
else:
|
||||
return {}
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
pass
|
||||
response = await self.acreate(messages, stream=False)
|
||||
usage = self.get_usage(response)
|
||||
self._update_costs(usage)
|
||||
return response
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
return await self._achat_completion(messages, timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
response = await self.acreate(messages, stream=True)
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in response:
|
||||
collected_content.append(chunk.content)
|
||||
log_llm_stream(chunk.content)
|
||||
if hasattr(chunk, "additional_kwargs"):
|
||||
usage = chunk.additional_kwargs.get("token_usage", {})
|
||||
|
||||
class GetMessageFromWeb:
|
||||
class WsParam:
|
||||
"""
|
||||
该类适合讯飞星火大部分接口的调用。
|
||||
输入 app_id, api_key, api_secret, spark_url以初始化,
|
||||
create_url方法返回接口url
|
||||
"""
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
# 初始化
|
||||
def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
|
||||
self.app_id = app_id
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.host = urlparse(spark_url).netloc
|
||||
self.path = urlparse(spark_url).path
|
||||
self.spark_url = spark_url
|
||||
self.message = message
|
||||
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
|
||||
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(
|
||||
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
|
||||
).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + "?" + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def __init__(self, text, config: LLMConfig):
|
||||
self.text = text
|
||||
self.ret = ""
|
||||
self.spark_appid = config.app_id
|
||||
self.spark_api_secret = config.api_secret
|
||||
self.spark_api_key = config.api_key
|
||||
self.domain = config.domain
|
||||
self.spark_url = config.base_url
|
||||
|
||||
def on_message(self, ws, message):
|
||||
data = json.loads(message)
|
||||
code = data["header"]["code"]
|
||||
|
||||
if code != 0:
|
||||
ws.close() # 请求错误,则关闭socket
|
||||
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
|
||||
return
|
||||
async def acreate(self, messages: list[dict], stream: bool = True):
|
||||
messages = convert_to_messages(messages)
|
||||
if stream:
|
||||
return self.client.astream(messages)
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
# seq = choices["seq"] # 服务端是流式返回,seq为返回的数据序号
|
||||
status = choices["status"] # 服务端是流式返回,status用于判断信息是否传送完毕
|
||||
content = choices["text"][0]["content"] # 本次接收到的回答文本
|
||||
self.ret += content
|
||||
if status == 2:
|
||||
ws.close()
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(self, ws, error):
|
||||
# on_message方法处理接收到的信息,出现任何错误,都会调用这个方法
|
||||
logger.critical(f"通讯连接出错,【错误提示: {error}】")
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(self, ws, one, two):
|
||||
pass
|
||||
|
||||
# 处理请求数据
|
||||
def gen_params(self):
|
||||
data = {
|
||||
"header": {"app_id": self.spark_appid, "uid": "1234"},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
# domain为必传参数
|
||||
"domain": self.domain,
|
||||
# 以下为可微调,非必传参数
|
||||
# 注意:官方建议,temperature和top_k修改一个即可
|
||||
"max_tokens": 2048, # 默认2048,模型回答的tokens的最大长度,即允许它输出文本的最长字数
|
||||
"temperature": 0.5, # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高
|
||||
"top_k": 4, # 取值为[1,6],默认为4。从k个候选中随机选择一个(非等概率)
|
||||
}
|
||||
},
|
||||
"payload": {"message": {"text": self.text}},
|
||||
}
|
||||
return data
|
||||
|
||||
def send(self, ws, *args):
|
||||
data = json.dumps(self.gen_params())
|
||||
ws.send(data)
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(self, ws):
|
||||
thread.start_new_thread(self.send, (ws,))
|
||||
|
||||
# 处理收到的 websocket消息,出现任何错误,调用on_error方法
|
||||
def run(self):
|
||||
return self._run(self.text)
|
||||
|
||||
def _run(self, text_list):
|
||||
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
|
||||
ws_url = ws_param.create_url()
|
||||
|
||||
websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能
|
||||
ws = websocket.WebSocketApp(
|
||||
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
|
||||
)
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
return self.ret
|
||||
return await self.client.agenerate([messages])
|
||||
|
|
|
|||
|
|
@ -89,9 +89,9 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
return OllamaEmbedding(**params)
|
||||
|
||||
def _try_set_model_and_batch_size(self, params: dict):
|
||||
"""Set the model and embed_batch_size only when they are specified."""
|
||||
"""Set the model_name and embed_batch_size only when they are specified."""
|
||||
if config.embedding.model:
|
||||
params["model"] = config.embedding.model
|
||||
params["model_name"] = config.embedding.model
|
||||
|
||||
if config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = config.embedding.embed_batch_size
|
||||
|
|
|
|||
|
|
@ -136,4 +136,4 @@ class Assistant(Role):
|
|||
try:
|
||||
self.memory = BrainMemory(**m)
|
||||
except Exception as e:
|
||||
logger.exception(f"load error:{e}, data:{jsn}")
|
||||
logger.exception(f"load error:{e}, data:{m}")
|
||||
|
|
|
|||
|
|
@ -53,30 +53,30 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
await page.wait_for_load_state("networkidle")
|
||||
|
||||
await page.wait_for_selector("div#container", state="attached")
|
||||
# mermaid_config = {}
|
||||
mermaid_config = {}
|
||||
background_color = "#ffffff"
|
||||
# my_css = ""
|
||||
my_css = ""
|
||||
await page.evaluate(f'document.body.style.background = "{background_color}";')
|
||||
|
||||
# metadata = await page.evaluate(
|
||||
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
# const { mermaid, zenuml } = globalThis;
|
||||
# await mermaid.registerExternalDiagrams([zenuml]);
|
||||
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
# document.getElementById('container').innerHTML = svg;
|
||||
# const svgElement = document.querySelector('svg');
|
||||
# svgElement.style.backgroundColor = backgroundColor;
|
||||
#
|
||||
# if (myCSS) {
|
||||
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
# style.appendChild(document.createTextNode(myCSS));
|
||||
# svgElement.appendChild(style);
|
||||
# }
|
||||
#
|
||||
# }""",
|
||||
# [mermaid_code, mermaid_config, my_css, background_color],
|
||||
# )
|
||||
await page.evaluate(
|
||||
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
const { mermaid, zenuml } = globalThis;
|
||||
await mermaid.registerExternalDiagrams([zenuml]);
|
||||
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
document.getElementById('container').innerHTML = svg;
|
||||
const svgElement = document.querySelector('svg');
|
||||
svgElement.style.backgroundColor = backgroundColor;
|
||||
|
||||
if (myCSS) {
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.appendChild(document.createTextNode(myCSS));
|
||||
svgElement.appendChild(style);
|
||||
}
|
||||
|
||||
}""",
|
||||
[mermaid_code, mermaid_config, my_css, background_color],
|
||||
)
|
||||
|
||||
if "svg" in suffixes:
|
||||
svg_xml = await page.evaluate(
|
||||
|
|
|
|||
|
|
@ -55,29 +55,29 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
await page.goto(mermaid_html_url)
|
||||
|
||||
await page.querySelector("div#container")
|
||||
# mermaid_config = {}
|
||||
mermaid_config = {}
|
||||
background_color = "#ffffff"
|
||||
# my_css = ""
|
||||
my_css = ""
|
||||
await page.evaluate(f'document.body.style.background = "{background_color}";')
|
||||
|
||||
# metadata = await page.evaluate(
|
||||
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
# const { mermaid, zenuml } = globalThis;
|
||||
# await mermaid.registerExternalDiagrams([zenuml]);
|
||||
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
# document.getElementById('container').innerHTML = svg;
|
||||
# const svgElement = document.querySelector('svg');
|
||||
# svgElement.style.backgroundColor = backgroundColor;
|
||||
#
|
||||
# if (myCSS) {
|
||||
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
# style.appendChild(document.createTextNode(myCSS));
|
||||
# svgElement.appendChild(style);
|
||||
# }
|
||||
# }""",
|
||||
# [mermaid_code, mermaid_config, my_css, background_color],
|
||||
# )
|
||||
await page.evaluate(
|
||||
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
const { mermaid, zenuml } = globalThis;
|
||||
await mermaid.registerExternalDiagrams([zenuml]);
|
||||
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
document.getElementById('container').innerHTML = svg;
|
||||
const svgElement = document.querySelector('svg');
|
||||
svgElement.style.backgroundColor = backgroundColor;
|
||||
|
||||
if (myCSS) {
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.appendChild(document.createTextNode(myCSS));
|
||||
svgElement.appendChild(style);
|
||||
}
|
||||
}""",
|
||||
[mermaid_code, mermaid_config, my_css, background_color],
|
||||
)
|
||||
|
||||
if "svg" in suffixes:
|
||||
svg_xml = await page.evaluate(
|
||||
|
|
|
|||
|
|
@ -68,6 +68,15 @@ TOKEN_COSTS = {
|
|||
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
|
||||
# For ark model https://www.volcengine.com/docs/82379/1099320
|
||||
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
|
||||
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
|
||||
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
|
||||
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
|
||||
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
|
||||
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
|
||||
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -213,6 +222,12 @@ TOKEN_MAX = {
|
|||
"openai/gpt-4-turbo-preview": 128000,
|
||||
"deepseek-chat": 32768,
|
||||
"deepseek-coder": 16385,
|
||||
"doubao-lite-4k-240515": 4000,
|
||||
"doubao-lite-32k-240515": 32000,
|
||||
"doubao-lite-128k-240515": 128000,
|
||||
"doubao-pro-4k-240515": 4000,
|
||||
"doubao-pro-32k-240515": 32000,
|
||||
"doubao-pro-128k-240515": 128000,
|
||||
}
|
||||
|
||||
# For Amazon Bedrock US region
|
||||
|
|
@ -262,6 +277,14 @@ BEDROCK_TOKEN_COSTS = {
|
|||
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
|
||||
}
|
||||
|
||||
# https://xinghuo.xfyun.cn/sparkapi?scr=price
|
||||
SPARK_TOKENS = {
|
||||
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
|
||||
"generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
|
||||
"generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
|
||||
"generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
|
||||
}
|
||||
|
||||
|
||||
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ wrapt==1.15.0
|
|||
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py
|
||||
#aioboto3~=12.4.0 # Used by metagpt/utils/s3.py
|
||||
aioredis~=2.0.1 # Used by metagpt/utils/redis.py
|
||||
websocket-client==1.6.2
|
||||
websocket-client~=1.8.0
|
||||
aiofiles==23.2.1
|
||||
gitpython==3.1.40
|
||||
zhipuai==2.0.1
|
||||
|
|
@ -71,3 +71,4 @@ dashscope==1.14.1
|
|||
rank-bm25==0.2.2 # for tool recommendation
|
||||
gymnasium==0.29.1
|
||||
boto3~=1.34.69
|
||||
spark_ai_python~=0.3.30
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -38,7 +38,7 @@ DESIGN = {
|
|||
|
||||
|
||||
TASK = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
["game.py", "Contains Game class and related functions for game logic"],
|
||||
|
|
|
|||
|
|
@ -69,3 +69,5 @@ mock_llm_config_bedrock = LLMConfig(
|
|||
secret_key="123abc",
|
||||
max_token=10000,
|
||||
)
|
||||
|
||||
mock_llm_config_ark = LLMConfig(api_type="ark", api_key="eyxxx", base_url="xxx", model="ep-xxx")
|
||||
|
|
|
|||
85
tests/metagpt/provider/test_ark.py
Normal file
85
tests/metagpt/provider/test_ark.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""
|
||||
用于火山方舟Python SDK V3的测试用例
|
||||
API文档:https://www.volcengine.com/docs/82379/1263482
|
||||
"""
|
||||
|
||||
from typing import AsyncIterator, List, Union
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
|
||||
|
||||
from metagpt.provider.ark_api import ArkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_ark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "AI assistant"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
USAGE = {"completion_tokens": 1000, "prompt_tokens": 1000, "total_tokens": 2000}
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp.model = "doubao-pro-32k-240515"
|
||||
default_resp.usage = USAGE
|
||||
|
||||
|
||||
def create_chat_completion_chunk(
|
||||
content: str, finish_reason: str = None, choices: List[Choice] = None
|
||||
) -> ChatCompletionChunk:
|
||||
if choices is None:
|
||||
choices = [
|
||||
Choice(
|
||||
delta=ChoiceDelta(content=content, function_call=None, role="assistant", tool_calls=None),
|
||||
finish_reason=finish_reason,
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id="012",
|
||||
choices=choices,
|
||||
created=1716278586,
|
||||
model="doubao-pro-32k-240515",
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint=None,
|
||||
usage=None if choices else USAGE,
|
||||
)
|
||||
|
||||
|
||||
ark_resp_chunk = create_chat_completion_chunk(content="")
|
||||
ark_resp_chunk_finish = create_chat_completion_chunk(content=resp_cont, finish_reason="stop")
|
||||
ark_resp_chunk_last = create_chat_completion_chunk(content="", choices=[])
|
||||
|
||||
|
||||
async def chunk_iterator(chunks: List[ChatCompletionChunk]) -> AsyncIterator[ChatCompletionChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def mock_ark_acompletions_create(
|
||||
self, stream: bool = False, **kwargs
|
||||
) -> Union[ChatCompletionChunk, ChatCompletion]:
|
||||
if stream:
|
||||
chunks = [ark_resp_chunk, ark_resp_chunk_finish, ark_resp_chunk_last]
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ark_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_ark_acompletions_create)
|
||||
|
||||
llm = ArkLLM(mock_llm_config_ark)
|
||||
|
||||
resp = await llm.acompletion(messages)
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.choices[0].message.content == resp_cont
|
||||
assert resp.usage == USAGE
|
||||
|
||||
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
|
||||
|
|
@ -1,62 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
"""
|
||||
用于讯飞星火SDK的测试用例
|
||||
文档:https://www.xfyun.cn/doc/spark/Web.html
|
||||
"""
|
||||
|
||||
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
|
||||
from sparkai.core.outputs.chat_generation import ChatGeneration
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
USAGE = {
|
||||
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
|
||||
}
|
||||
spark_agenerate_result = LLMResult(
|
||||
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
|
||||
)
|
||||
|
||||
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
assert resp == resp_cont
|
||||
async def mock_spark_acreate(self, messages, stream):
|
||||
if stream:
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return spark_agenerate_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
|
||||
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
resp = await spark_llm.acompletion([messages])
|
||||
assert spark_llm.get_choice_text(resp) == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.utils.mermaid import MMC1, mermaid_to_file
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink", "playwright", "pyppeteer"])
|
||||
async def test_mermaid(engine, context, mermaid_mocker):
|
||||
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
|
||||
# ink prerequisites: connected to internet
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue