implement spark chat

This commit is contained in:
usamimeri_renko 2024-05-22 21:10:55 +08:00
parent e03db313a2
commit 1b13a28a77
3 changed files with 56 additions and 212 deletions

View file

@ -1,175 +1,78 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : spark_api.py
"""
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from sparkai.core.messages import _convert_to_message, convert_to_messages
from sparkai.llm.llm import ChatSparkLLM
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
# from sparkai.schema import LLMResult, HumanMessage, AIMessage 由于其使用Pydantic V1,导入会报错
@register_provider(LLMType.SPARK)
class SparkLLM(BaseLLM):
"""
用于讯飞星火大模型系列
参考https://github.com/iflytek/spark-ai-python"""
def __init__(self, config: LLMConfig):
self.config = config
logger.warning("SparkLLM当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
self._init_client()
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
def _init_client(self):
self.client = ChatSparkLLM(
spark_api_url=self.config.base_url,
spark_app_id=self.config.app_id,
spark_api_key=self.config.api_key,
spark_api_secret=self.config.api_secret,
spark_llm_domain=self.config.domain,
streaming=True,
)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
def _system_msg(self, msg: str):
return _convert_to_message(msg)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
# 不支持
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
def _user_msg(self, msg: str, **kwargs):
return _convert_to_message(msg)
def _assistant_msg(self, msg: str):
return _convert_to_message(msg)
def get_choice_text(self, rsp) -> str:
return rsp.generations[0][0].text
def get_usage(self, response):
message = response.generations[0][0].message
if hasattr(message, "additional_kwargs"):
return message.additional_kwargs.get("token_usage", {})
else:
return {}
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
messages = convert_to_messages(messages)
response = await self.client.agenerate([messages])
usage = self.get_usage(response)
self._update_costs(usage)
return response
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
# 不支持异步
w = GetMessageFromWeb(messages, self.config)
return w.run()
return self._achat_completion(messages, timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
response = self.client.astream(messages)
collected_content = []
usage = {}
async for chunk in response:
collected_content.append(chunk.content)
log_llm_stream(chunk.content)
if hasattr(chunk, "additional_kwargs"):
usage = chunk.additional_kwargs.get("token_usage", {})
class GetMessageFromWeb:
class WsParam:
"""
该类适合讯飞星火大部分接口的调用
输入 app_id, api_key, api_secret, spark_url以初始化
create_url方法返回接口url
"""
# 初始化
def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
self.app_id = app_id
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
self.spark_url = spark_url
self.message = message
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def __init__(self, text, config: LLMConfig):
self.text = text
self.ret = ""
self.spark_appid = config.app_id
self.spark_api_secret = config.api_secret
self.spark_api_key = config.api_key
self.domain = config.domain
self.spark_url = config.base_url
def on_message(self, ws, message):
data = json.loads(message)
code = data["header"]["code"]
if code != 0:
ws.close() # 请求错误则关闭socket
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
return
else:
choices = data["payload"]["choices"]
# seq = choices["seq"] # 服务端是流式返回seq为返回的数据序号
status = choices["status"] # 服务端是流式返回status用于判断信息是否传送完毕
content = choices["text"][0]["content"] # 本次接收到的回答文本
self.ret += content
if status == 2:
ws.close()
# 收到websocket错误的处理
def on_error(self, ws, error):
# on_message方法处理接收到的信息出现任何错误都会调用这个方法
logger.critical(f"通讯连接出错,【错误提示: {error}")
# 收到websocket关闭的处理
def on_close(self, ws, one, two):
pass
# 处理请求数据
def gen_params(self):
data = {
"header": {"app_id": self.spark_appid, "uid": "1234"},
"parameter": {
"chat": {
# domain为必传参数
"domain": self.domain,
# 以下为可微调,非必传参数
# 注意官方建议temperature和top_k修改一个即可
"max_tokens": 2048, # 默认2048模型回答的tokens的最大长度即允许它输出文本的最长字数
"temperature": 0.5, # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高
"top_k": 4, # 取值为[16],默认为4。从k个候选中随机选择一个非等概率
}
},
"payload": {"message": {"text": self.text}},
}
return data
def send(self, ws, *args):
data = json.dumps(self.gen_params())
ws.send(data)
# 收到websocket连接建立的处理
def on_open(self, ws):
thread.start_new_thread(self.send, (ws,))
# 处理收到的 websocket消息出现任何错误调用on_error方法
def run(self):
return self._run(self.text)
def _run(self, text_list):
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
ws_url = ws_param.create_url()
websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能
ws = websocket.WebSocketApp(
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
return self.ret
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content

View file

@ -44,7 +44,7 @@ wrapt==1.15.0
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py
#aioboto3~=12.4.0 # Used by metagpt/utils/s3.py
aioredis~=2.0.1 # Used by metagpt/utils/redis.py
websocket-client==1.6.2
websocket-client~=1.6.2
aiofiles==23.2.1
gitpython==3.1.40
zhipuai==2.0.1

View file

@ -1,62 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
import pytest
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="Spark")
class MockWebSocketApp(object):
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
pass
def run_forever(self, sslopt=None):
pass
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
assert ret == ""
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
assert resp == resp_cont
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_llm = SparkLLM(mock_llm_config)
resp = await spark_llm.acompletion([])
assert resp == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)