mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge pull request #1291 from usamimeri/update_spark
Feat: support async spark api
This commit is contained in:
commit
519be03d4a
5 changed files with 123 additions and 191 deletions
10
config/examples/spark_lite.yaml
Normal file
10
config/examples/spark_lite.yaml
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
|
||||
|
||||
llm:
|
||||
api_type: "spark"
|
||||
# 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
|
||||
base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"
|
||||
app_id: ""
|
||||
api_key: ""
|
||||
api_secret: ""
|
||||
domain: "general" # 取值为 [general,generalv2,generalv3,generalv3.5] 和url一一对应
|
||||
|
|
@ -1,175 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File : spark_api.py
|
||||
"""
|
||||
import _thread as thread
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import ssl
|
||||
from time import mktime
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import websocket # 使用websocket_client
|
||||
from sparkai.core.messages import _convert_to_message, convert_to_messages
|
||||
from sparkai.core.messages.ai import AIMessage
|
||||
from sparkai.core.messages.base import BaseMessage
|
||||
from sparkai.core.messages.human import HumanMessage
|
||||
from sparkai.core.messages.system import SystemMessage
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
from sparkai.llm.llm import ChatSparkLLM
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import SPARK_TOKENS
|
||||
|
||||
|
||||
@register_provider(LLMType.SPARK)
|
||||
class SparkLLM(BaseLLM):
|
||||
"""
|
||||
用于讯飞星火大模型系列
|
||||
参考:https://github.com/iflytek/spark-ai-python"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
logger.warning("SparkLLM:当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
|
||||
self.model = self.config.domain
|
||||
self._init_client()
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
def _init_client(self):
|
||||
self.client = ChatSparkLLM(
|
||||
spark_api_url=self.config.base_url,
|
||||
spark_app_id=self.config.app_id,
|
||||
spark_api_key=self.config.api_key,
|
||||
spark_api_secret=self.config.api_secret,
|
||||
spark_llm_domain=self.config.domain,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
pass
|
||||
def _system_msg(self, msg: str) -> SystemMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
# 不支持
|
||||
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def _assistant_msg(self, msg: str) -> AIMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def get_choice_text(self, rsp: LLMResult) -> str:
|
||||
return rsp.generations[0][0].text
|
||||
|
||||
def get_usage(self, response: LLMResult):
|
||||
message = response.generations[0][0].message
|
||||
if hasattr(message, "additional_kwargs"):
|
||||
return message.additional_kwargs.get("token_usage", {})
|
||||
else:
|
||||
return {}
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
pass
|
||||
response = await self.acreate(messages, stream=False)
|
||||
usage = self.get_usage(response)
|
||||
self._update_costs(usage)
|
||||
return response
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
return await self._achat_completion(messages, timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
response = await self.acreate(messages, stream=True)
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in response:
|
||||
collected_content.append(chunk.content)
|
||||
log_llm_stream(chunk.content)
|
||||
if hasattr(chunk, "additional_kwargs"):
|
||||
usage = chunk.additional_kwargs.get("token_usage", {})
|
||||
|
||||
class GetMessageFromWeb:
|
||||
class WsParam:
|
||||
"""
|
||||
该类适合讯飞星火大部分接口的调用。
|
||||
输入 app_id, api_key, api_secret, spark_url以初始化,
|
||||
create_url方法返回接口url
|
||||
"""
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
# 初始化
|
||||
def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
|
||||
self.app_id = app_id
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.host = urlparse(spark_url).netloc
|
||||
self.path = urlparse(spark_url).path
|
||||
self.spark_url = spark_url
|
||||
self.message = message
|
||||
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
|
||||
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(
|
||||
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
|
||||
).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + "?" + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def __init__(self, text, config: LLMConfig):
|
||||
self.text = text
|
||||
self.ret = ""
|
||||
self.spark_appid = config.app_id
|
||||
self.spark_api_secret = config.api_secret
|
||||
self.spark_api_key = config.api_key
|
||||
self.domain = config.domain
|
||||
self.spark_url = config.base_url
|
||||
|
||||
def on_message(self, ws, message):
|
||||
data = json.loads(message)
|
||||
code = data["header"]["code"]
|
||||
|
||||
if code != 0:
|
||||
ws.close() # 请求错误,则关闭socket
|
||||
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
|
||||
return
|
||||
async def acreate(self, messages: list[dict], stream: bool = True):
|
||||
messages = convert_to_messages(messages)
|
||||
if stream:
|
||||
return self.client.astream(messages)
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
# seq = choices["seq"] # 服务端是流式返回,seq为返回的数据序号
|
||||
status = choices["status"] # 服务端是流式返回,status用于判断信息是否传送完毕
|
||||
content = choices["text"][0]["content"] # 本次接收到的回答文本
|
||||
self.ret += content
|
||||
if status == 2:
|
||||
ws.close()
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(self, ws, error):
|
||||
# on_message方法处理接收到的信息,出现任何错误,都会调用这个方法
|
||||
logger.critical(f"通讯连接出错,【错误提示: {error}】")
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(self, ws, one, two):
|
||||
pass
|
||||
|
||||
# 处理请求数据
|
||||
def gen_params(self):
|
||||
data = {
|
||||
"header": {"app_id": self.spark_appid, "uid": "1234"},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
# domain为必传参数
|
||||
"domain": self.domain,
|
||||
# 以下为可微调,非必传参数
|
||||
# 注意:官方建议,temperature和top_k修改一个即可
|
||||
"max_tokens": 2048, # 默认2048,模型回答的tokens的最大长度,即允许它输出文本的最长字数
|
||||
"temperature": 0.5, # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高
|
||||
"top_k": 4, # 取值为[1,6],默认为4。从k个候选中随机选择一个(非等概率)
|
||||
}
|
||||
},
|
||||
"payload": {"message": {"text": self.text}},
|
||||
}
|
||||
return data
|
||||
|
||||
def send(self, ws, *args):
|
||||
data = json.dumps(self.gen_params())
|
||||
ws.send(data)
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(self, ws):
|
||||
thread.start_new_thread(self.send, (ws,))
|
||||
|
||||
# 处理收到的 websocket消息,出现任何错误,调用on_error方法
|
||||
def run(self):
|
||||
return self._run(self.text)
|
||||
|
||||
def _run(self, text_list):
|
||||
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
|
||||
ws_url = ws_param.create_url()
|
||||
|
||||
websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能
|
||||
ws = websocket.WebSocketApp(
|
||||
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
|
||||
)
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
return self.ret
|
||||
return await self.client.agenerate([messages])
|
||||
|
|
|
|||
|
|
@ -258,6 +258,14 @@ BEDROCK_TOKEN_COSTS = {
|
|||
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
|
||||
}
|
||||
|
||||
# https://xinghuo.xfyun.cn/sparkapi?scr=price
|
||||
SPARK_TOKENS = {
|
||||
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
|
||||
"generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
|
||||
"generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
|
||||
"generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
|
||||
}
|
||||
|
||||
|
||||
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ wrapt==1.15.0
|
|||
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py
|
||||
#aioboto3~=12.4.0 # Used by metagpt/utils/s3.py
|
||||
aioredis~=2.0.1 # Used by metagpt/utils/redis.py
|
||||
websocket-client==1.6.2
|
||||
websocket-client~=1.8.0
|
||||
aiofiles==23.2.1
|
||||
gitpython==3.1.40
|
||||
zhipuai==2.0.1
|
||||
|
|
@ -71,3 +71,4 @@ dashscope==1.14.1
|
|||
rank-bm25==0.2.2 # for tool recommendation
|
||||
gymnasium==0.29.1
|
||||
boto3~=1.34.69
|
||||
spark_ai_python~=0.3.30
|
||||
|
|
@ -1,62 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
"""
|
||||
用于讯飞星火SDK的测试用例
|
||||
文档:https://www.xfyun.cn/doc/spark/Web.html
|
||||
"""
|
||||
|
||||
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
|
||||
from sparkai.core.outputs.chat_generation import ChatGeneration
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
USAGE = {
|
||||
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
|
||||
}
|
||||
spark_agenerate_result = LLMResult(
|
||||
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
|
||||
)
|
||||
|
||||
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
assert resp == resp_cont
|
||||
async def mock_spark_acreate(self, messages, stream):
|
||||
if stream:
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return spark_agenerate_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
|
||||
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
resp = await spark_llm.acompletion([messages])
|
||||
assert spark_llm.get_choice_text(resp) == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue