diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 594267259..36b7efedd 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -1,175 +1,78 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@File : spark_api.py -""" -import _thread as thread -import base64 -import datetime -import hashlib -import hmac -import json -import ssl -from time import mktime -from urllib.parse import urlencode, urlparse -from wsgiref.handlers import format_date_time -import websocket # 使用websocket_client +from sparkai.core.messages import _convert_to_message, convert_to_messages +from sparkai.llm.llm import ChatSparkLLM from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.const import USE_CONFIG_TIMEOUT -from metagpt.logs import logger +from metagpt.logs import log_llm_stream from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider +# from sparkai.schema import LLMResult, HumanMessage, AIMessage 由于其使用Pydantic V1,导入会报错 + @register_provider(LLMType.SPARK) class SparkLLM(BaseLLM): + """ + 用于讯飞星火大模型系列 + 参考:https://github.com/iflytek/spark-ai-python""" + def __init__(self, config: LLMConfig): self.config = config - logger.warning("SparkLLM:当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") + self._init_client() - def get_choice_text(self, rsp: dict) -> str: - return rsp["payload"]["choices"]["text"][-1]["content"] + def _init_client(self): + self.client = ChatSparkLLM( + spark_api_url=self.config.base_url, + spark_app_id=self.config.app_id, + spark_api_key=self.config.api_key, + spark_api_secret=self.config.api_secret, + spark_llm_domain=self.config.domain, + streaming=True, + ) - async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: - pass + def _system_msg(self, msg: str): + return _convert_to_message(msg) - async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str: - # 不支持 - # logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") - w = GetMessageFromWeb(messages, self.config) - return w.run() + def _user_msg(self, msg: str, **kwargs): + return _convert_to_message(msg) + + def _assistant_msg(self, msg: str): + return _convert_to_message(msg) + + def get_choice_text(self, rsp) -> str: + return rsp.generations[0][0].text + + def get_usage(self, response): + message = response.generations[0][0].message + if hasattr(message, "additional_kwargs"): + return message.additional_kwargs.get("token_usage", {}) + else: + return {} async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - pass + messages = convert_to_messages(messages) + response = await self.client.agenerate([messages]) + usage = self.get_usage(response) + self._update_costs(usage) + return response async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - # 不支持异步 - w = GetMessageFromWeb(messages, self.config) - return w.run() + return self._achat_completion(messages, timeout) + async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: + response = self.client.astream(messages) + collected_content = [] + usage = {} + async for chunk in response: + collected_content.append(chunk.content) + log_llm_stream(chunk.content) + if hasattr(chunk, "additional_kwargs"): + usage = chunk.additional_kwargs.get("token_usage", {}) -class GetMessageFromWeb: - class WsParam: - """ - 该类适合讯飞星火大部分接口的调用。 - 输入 app_id, api_key, api_secret, spark_url以初始化, - create_url方法返回接口url - """ - - # 初始化 - def __init__(self, app_id, api_key, api_secret, spark_url, message=None): - self.app_id = app_id - self.api_key = api_key - self.api_secret = api_secret - self.host = urlparse(spark_url).netloc - self.path = urlparse(spark_url).path - self.spark_url = spark_url - self.message = message - - # 生成url - def create_url(self): - # 生成RFC1123格式的时间戳 - now = datetime.datetime.now() - date = format_date_time(mktime(now.timetuple())) - - # 拼接字符串 - signature_origin = "host: " + self.host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + self.path + " HTTP/1.1" - - # 进行hmac-sha256进行加密 - signature_sha = hmac.new( - self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 - ).digest() - - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") - - authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - - authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") - - # 将请求的鉴权参数组合为字典 - v = {"authorization": authorization, "date": date, "host": self.host} - # 拼接鉴权参数,生成url - url = self.spark_url + "?" + urlencode(v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - return url - - def __init__(self, text, config: LLMConfig): - self.text = text - self.ret = "" - self.spark_appid = config.app_id - self.spark_api_secret = config.api_secret - self.spark_api_key = config.api_key - self.domain = config.domain - self.spark_url = config.base_url - - def on_message(self, ws, message): - data = json.loads(message) - code = data["header"]["code"] - - if code != 0: - ws.close() # 请求错误,则关闭socket - logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}") - return - else: - choices = data["payload"]["choices"] - # seq = choices["seq"] # 服务端是流式返回,seq为返回的数据序号 - status = choices["status"] # 服务端是流式返回,status用于判断信息是否传送完毕 - content = choices["text"][0]["content"] # 本次接收到的回答文本 - self.ret += content - if status == 2: - ws.close() - - # 收到websocket错误的处理 - def on_error(self, ws, error): - # on_message方法处理接收到的信息,出现任何错误,都会调用这个方法 - logger.critical(f"通讯连接出错,【错误提示: {error}】") - - # 收到websocket关闭的处理 - def on_close(self, ws, one, two): - pass - - # 处理请求数据 - def gen_params(self): - data = { - "header": {"app_id": self.spark_appid, "uid": "1234"}, - "parameter": { - "chat": { - # domain为必传参数 - "domain": self.domain, - # 以下为可微调,非必传参数 - # 注意:官方建议,temperature和top_k修改一个即可 - "max_tokens": 2048, # 默认2048,模型回答的tokens的最大长度,即允许它输出文本的最长字数 - "temperature": 0.5, # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高 - "top_k": 4, # 取值为[1,6],默认为4。从k个候选中随机选择一个(非等概率) - } - }, - "payload": {"message": {"text": self.text}}, - } - return data - - def send(self, ws, *args): - data = json.dumps(self.gen_params()) - ws.send(data) - - # 收到websocket连接建立的处理 - def on_open(self, ws): - thread.start_new_thread(self.send, (ws,)) - - # 处理收到的 websocket消息,出现任何错误,调用on_error方法 - def run(self): - return self._run(self.text) - - def _run(self, text_list): - ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list) - ws_url = ws_param.create_url() - - websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能 - ws = websocket.WebSocketApp( - ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open - ) - ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) - return self.ret + log_llm_stream("\n") + self._update_costs(usage) + full_content = "".join(collected_content) + return full_content diff --git a/requirements.txt b/requirements.txt index b40972efb..0d006d0b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,7 +44,7 @@ wrapt==1.15.0 # azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py #aioboto3~=12.4.0 # Used by metagpt/utils/s3.py aioredis~=2.0.1 # Used by metagpt/utils/redis.py -websocket-client==1.6.2 +websocket-client~=1.6.2 aiofiles==23.2.1 gitpython==3.1.40 zhipuai==2.0.1 diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 9c278267d..043d98d13 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -1,62 +1,3 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : the unittest of spark api - -import pytest - -from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM -from tests.metagpt.provider.mock_llm_config import ( - mock_llm_config, - mock_llm_config_spark, -) -from tests.metagpt.provider.req_resp_const import ( - llm_general_chat_funcs_test, - prompt, - resp_cont_tmpl, -) - -resp_cont = resp_cont_tmpl.format(name="Spark") - - -class MockWebSocketApp(object): - def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None): - pass - - def run_forever(self, sslopt=None): - pass - - -def test_get_msg_from_web(mocker): - mocker.patch("websocket.WebSocketApp", MockWebSocketApp) - - get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config) - assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain" - - ret = get_msg_from_web.run() - assert ret == "" - - -def mock_spark_get_msg_from_web_run(self) -> str: - return resp_cont - - -@pytest.mark.asyncio -async def test_spark_aask(mocker): - mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - - llm = SparkLLM(mock_llm_config_spark) - - resp = await llm.aask("Hello!") - assert resp == resp_cont - - -@pytest.mark.asyncio -async def test_spark_acompletion(mocker): - mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - - spark_llm = SparkLLM(mock_llm_config) - - resp = await spark_llm.acompletion([]) - assert resp == resp_cont - - await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)