Merge branch 'geekan:main' into update_qwen

This commit is contained in:
usamimeri_renko 2024-05-29 16:27:49 +08:00 committed by GitHub
commit 4e3c46a5fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 389 additions and 294 deletions

View file

@ -84,7 +84,7 @@ class WriteTasks(Action):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python packages", set()))
packages = set(m.get("Required packages", set()))
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")

View file

@ -9,10 +9,10 @@ from typing import List
from metagpt.actions.action_node import ActionNode
REQUIRED_PYTHON_PACKAGES = ActionNode(
key="Required Python packages",
REQUIRED_PACKAGES = ActionNode(
key="Required packages",
expected_type=List[str],
instruction="Provide required Python packages in requirements.txt format.",
instruction="Provide required packages in requirements.txt format.",
example=["flask==1.1.2", "bcrypt==3.2.0"],
)
@ -97,7 +97,7 @@ ANYTHING_UNCLEAR_PM = ActionNode(
)
NODES = [
REQUIRED_PYTHON_PACKAGES,
REQUIRED_PACKAGES,
REQUIRED_OTHER_LANGUAGE_PACKAGES,
LOGIC_ANALYSIS,
TASK_LIST,
@ -107,7 +107,7 @@ NODES = [
]
REFINED_NODES = [
REQUIRED_PYTHON_PACKAGES,
REQUIRED_PACKAGES,
REQUIRED_OTHER_LANGUAGE_PACKAGES,
REFINED_LOGIC_ANALYSIS,
REFINED_TASK_LIST,

View file

@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
end", "Anything UNCLEAR": "目前项目要求明确没有不清楚的地方"}
## Tasks
{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
## Code Files
----- index.html

View file

@ -33,6 +33,7 @@ class LLMType(Enum):
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
BEDROCK = "bedrock"
ARK = "ark"
def __missing__(self, key):
return self.OPENAI

View file

@ -56,10 +56,16 @@ def get_role_environment(sim_code: str, role_name: str, step: int = 0) -> dict:
def write_curr_sim_code(curr_sim_code: dict, temp_storage_path: Optional[Path] = None):
    """Persist the current simulation-code payload to ``curr_sim_code.json``.

    Args:
        curr_sim_code: JSON-serializable dict to write out.
        temp_storage_path: Target directory; falls back to TEMP_STORAGE_PATH
            when None. The None check must happen before Path() conversion,
            since ``Path(None)`` raises TypeError.
    """
    if temp_storage_path is None:
        temp_storage_path = TEMP_STORAGE_PATH
    else:
        temp_storage_path = Path(temp_storage_path)
    write_json_file(temp_storage_path.joinpath("curr_sim_code.json"), curr_sim_code)
def write_curr_step(curr_step: dict, temp_storage_path: Optional[Path] = None):
    """Persist the current simulation step payload to ``curr_step.json``.

    Args:
        curr_step: JSON-serializable dict to write out.
        temp_storage_path: Target directory; falls back to TEMP_STORAGE_PATH
            when None. The None check must happen before Path() conversion,
            since ``Path(None)`` raises TypeError.
    """
    if temp_storage_path is None:
        temp_storage_path = TEMP_STORAGE_PATH
    else:
        temp_storage_path = Path(temp_storage_path)
    write_json_file(temp_storage_path.joinpath("curr_step.json"), curr_step)

View file

@ -18,6 +18,7 @@ from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM
from metagpt.provider.ark_api import ArkLLM
__all__ = [
"GeminiLLM",
@ -32,4 +33,5 @@ __all__ = [
"DashScopeLLM",
"AnthropicLLM",
"BedrockLLM",
"ArkLLM",
]

View file

@ -0,0 +1,44 @@
from openai import AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from metagpt.configs.llm_config import LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
@register_provider(LLMType.ARK)
class ArkLLM(OpenAILLM):
    """LLM provider for Volcengine Ark (火山方舟).

    https://www.volcengine.com/docs/82379/1263482
    Uses the OpenAI-compatible endpoint, so it reuses OpenAILLM's client and
    kwargs construction and only overrides the two completion entry points.
    """

    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
        """Stream a chat completion, logging each chunk and recording token costs.

        Ark only reports usage when ``stream_options.include_usage`` is
        requested; it then arrives in the final chunk, which has an empty
        ``choices`` list.
        """
        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
            **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
            stream=True,
            # Without this extra body, Ark streams never return usage at all.
            extra_body={"stream_options": {"include_usage": True}},
        )
        usage = None
        model = self.model  # fallback so an empty stream cannot leave `model` unbound
        collected_messages = []
        async for chunk in response:
            chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
            log_llm_stream(chunk_message)
            collected_messages.append(chunk_message)
            model = chunk.model
            if chunk.usage:
                # Ark sends usage only in the last chunk, whose choices == [].
                # Recent openai SDKs already hand us a CompletionUsage object;
                # older ones give a dict — handle both (same fix as openai_api.py).
                if isinstance(chunk.usage, CompletionUsage):
                    usage = chunk.usage
                else:
                    usage = CompletionUsage(**chunk.usage)

        log_llm_stream("\n")
        full_reply_content = "".join(collected_messages)
        self._update_costs(usage, model)
        return full_reply_content

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
        """Non-streaming chat completion; records token costs from the response."""
        kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
        self._update_costs(rsp.usage, rsp.model)
        return rsp

View file

@ -102,7 +102,10 @@ class OpenAILLM(BaseLLM):
if finish_reason:
if hasattr(chunk, "usage") and chunk.usage is not None:
# Some services have usage as an attribute of the chunk, such as Fireworks
usage = CompletionUsage(**chunk.usage)
if isinstance(chunk.usage, CompletionUsage):
usage = chunk.usage
else:
usage = CompletionUsage(**chunk.usage)
elif hasattr(chunk.choices[0], "usage"):
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
usage = CompletionUsage(**chunk.choices[0].usage)

View file

@ -1,175 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : spark_api.py
"""
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from sparkai.core.messages import _convert_to_message, convert_to_messages
from sparkai.core.messages.ai import AIMessage
from sparkai.core.messages.base import BaseMessage
from sparkai.core.messages.human import HumanMessage
from sparkai.core.messages.system import SystemMessage
from sparkai.core.outputs.llm_result import LLMResult
from sparkai.llm.llm import ChatSparkLLM
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import any_to_str
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import SPARK_TOKENS
@register_provider(LLMType.SPARK)
class SparkLLM(BaseLLM):
"""
用于讯飞星火大模型系列
参考https://github.com/iflytek/spark-ai-python"""
def __init__(self, config: LLMConfig):
self.config = config
logger.warning("SparkLLM当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
self.model = self.config.domain
self._init_client()
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
def _init_client(self):
self.client = ChatSparkLLM(
spark_api_url=self.config.base_url,
spark_app_id=self.config.app_id,
spark_api_key=self.config.api_key,
spark_api_secret=self.config.api_secret,
spark_llm_domain=self.config.domain,
streaming=True,
)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
def _system_msg(self, msg: str) -> SystemMessage:
return _convert_to_message(msg)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
# 不支持
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
return _convert_to_message(msg)
def _assistant_msg(self, msg: str) -> AIMessage:
return _convert_to_message(msg)
def get_choice_text(self, rsp: LLMResult) -> str:
return rsp.generations[0][0].text
def get_usage(self, response: LLMResult):
message = response.generations[0][0].message
if hasattr(message, "additional_kwargs"):
return message.additional_kwargs.get("token_usage", {})
else:
return {}
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
response = await self.acreate(messages, stream=False)
usage = self.get_usage(response)
self._update_costs(usage)
return response
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
# 不支持异步
w = GetMessageFromWeb(messages, self.config)
return w.run()
return await self._achat_completion(messages, timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
response = await self.acreate(messages, stream=True)
collected_content = []
usage = {}
async for chunk in response:
collected_content.append(chunk.content)
log_llm_stream(chunk.content)
if hasattr(chunk, "additional_kwargs"):
usage = chunk.additional_kwargs.get("token_usage", {})
class GetMessageFromWeb:
class WsParam:
"""
该类适合讯飞星火大部分接口的调用
输入 app_id, api_key, api_secret, spark_url以初始化
create_url方法返回接口url
"""
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
# 初始化
def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
self.app_id = app_id
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
self.spark_url = spark_url
self.message = message
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def __init__(self, text, config: LLMConfig):
self.text = text
self.ret = ""
self.spark_appid = config.app_id
self.spark_api_secret = config.api_secret
self.spark_api_key = config.api_key
self.domain = config.domain
self.spark_url = config.base_url
def on_message(self, ws, message):
data = json.loads(message)
code = data["header"]["code"]
if code != 0:
ws.close() # 请求错误则关闭socket
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
return
async def acreate(self, messages: list[dict], stream: bool = True):
messages = convert_to_messages(messages)
if stream:
return self.client.astream(messages)
else:
choices = data["payload"]["choices"]
# seq = choices["seq"] # 服务端是流式返回seq为返回的数据序号
status = choices["status"] # 服务端是流式返回status用于判断信息是否传送完毕
content = choices["text"][0]["content"] # 本次接收到的回答文本
self.ret += content
if status == 2:
ws.close()
# 收到websocket错误的处理
def on_error(self, ws, error):
# on_message方法处理接收到的信息出现任何错误都会调用这个方法
logger.critical(f"通讯连接出错,【错误提示: {error}")
# 收到websocket关闭的处理
def on_close(self, ws, one, two):
pass
# 处理请求数据
def gen_params(self):
data = {
"header": {"app_id": self.spark_appid, "uid": "1234"},
"parameter": {
"chat": {
# domain为必传参数
"domain": self.domain,
# 以下为可微调,非必传参数
# 注意官方建议temperature和top_k修改一个即可
"max_tokens": 2048, # 默认2048模型回答的tokens的最大长度即允许它输出文本的最长字数
"temperature": 0.5, # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高
"top_k": 4, # 取值为[16],默认为4。从k个候选中随机选择一个非等概率
}
},
"payload": {"message": {"text": self.text}},
}
return data
def send(self, ws, *args):
data = json.dumps(self.gen_params())
ws.send(data)
# 收到websocket连接建立的处理
def on_open(self, ws):
thread.start_new_thread(self.send, (ws,))
# 处理收到的 websocket消息出现任何错误调用on_error方法
def run(self):
return self._run(self.text)
def _run(self, text_list):
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
ws_url = ws_param.create_url()
websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能
ws = websocket.WebSocketApp(
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
return self.ret
return await self.client.agenerate([messages])

View file

@ -89,9 +89,9 @@ class RAGEmbeddingFactory(GenericFactory):
return OllamaEmbedding(**params)
def _try_set_model_and_batch_size(self, params: dict):
"""Set the model and embed_batch_size only when they are specified."""
"""Set the model_name and embed_batch_size only when they are specified."""
if config.embedding.model:
params["model"] = config.embedding.model
params["model_name"] = config.embedding.model
if config.embedding.embed_batch_size:
params["embed_batch_size"] = config.embedding.embed_batch_size

View file

@ -136,4 +136,4 @@ class Assistant(Role):
try:
self.memory = BrainMemory(**m)
except Exception as e:
logger.exception(f"load error:{e}, data:{jsn}")
logger.exception(f"load error:{e}, data:{m}")

View file

@ -53,30 +53,30 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await page.wait_for_load_state("networkidle")
await page.wait_for_selector("div#container", state="attached")
# mermaid_config = {}
mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
#
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
await page.evaluate(
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}""",
[mermaid_code, mermaid_config, my_css, background_color],
)
if "svg" in suffixes:
svg_xml = await page.evaluate(

View file

@ -55,29 +55,29 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await page.goto(mermaid_html_url)
await page.querySelector("div#container")
# mermaid_config = {}
mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
await page.evaluate(
"""async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}""",
[mermaid_code, mermaid_config, my_css, background_color],
)
if "svg" in suffixes:
svg_xml = await page.evaluate(

View file

@ -68,6 +68,15 @@ TOKEN_COSTS = {
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
# For ark model https://www.volcengine.com/docs/82379/1099320
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
}
@ -213,6 +222,12 @@ TOKEN_MAX = {
"openai/gpt-4-turbo-preview": 128000,
"deepseek-chat": 32768,
"deepseek-coder": 16385,
"doubao-lite-4k-240515": 4000,
"doubao-lite-32k-240515": 32000,
"doubao-lite-128k-240515": 128000,
"doubao-pro-4k-240515": 4000,
"doubao-pro-32k-240515": 32000,
"doubao-pro-128k-240515": 128000,
}
# For Amazon Bedrock US region
@ -262,6 +277,14 @@ BEDROCK_TOKEN_COSTS = {
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
}
# https://xinghuo.xfyun.cn/sparkapi?scr=price
SPARK_TOKENS = {
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
"generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
"generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
"generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
}
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""