diff --git a/config/examples/spark_lite.yaml b/config/examples/spark_lite.yaml index 5164e73e5..15898e019 100644 --- a/config/examples/spark_lite.yaml +++ b/config/examples/spark_lite.yaml @@ -1,7 +1,7 @@ # 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E llm: - api_type: 'spark' + api_type: "spark" # 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat" app_id: "" diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 93fc533ea..bdac050d3 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -2,6 +2,11 @@ # -*- coding: utf-8 -*- from sparkai.core.messages import _convert_to_message, convert_to_messages +from sparkai.core.messages.ai import AIMessage +from sparkai.core.messages.base import BaseMessage +from sparkai.core.messages.human import HumanMessage +from sparkai.core.messages.system import SystemMessage +from sparkai.core.outputs.llm_result import LLMResult from sparkai.llm.llm import ChatSparkLLM from metagpt.configs.llm_config import LLMConfig, LLMType @@ -9,8 +14,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.logs import log_llm_stream from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider - -# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错 from metagpt.utils.common import any_to_str from metagpt.utils.cost_manager import CostManager from metagpt.utils.token_counter import SPARK_TOKENS @@ -38,19 +41,19 @@ class SparkLLM(BaseLLM): streaming=True, ) - def _system_msg(self, msg: str): + def _system_msg(self, msg: str) -> SystemMessage: return _convert_to_message(msg) - def _user_msg(self, msg: str, **kwargs): + def _user_msg(self, msg: str, **kwargs) -> HumanMessage: return _convert_to_message(msg) - def _assistant_msg(self, msg: str): + def _assistant_msg(self, msg: str) -> AIMessage: return _convert_to_message(msg) - def get_choice_text(self, rsp) -> str: + def get_choice_text(self, rsp: LLMResult) -> str: return rsp.generations[0][0].text - def get_usage(self, response): + def get_usage(self, response: LLMResult): message = response.generations[0][0].message if hasattr(message, "additional_kwargs"): return message.additional_kwargs.get("token_usage", {}) @@ -59,7 +62,7 @@ class SparkLLM(BaseLLM): async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): messages = convert_to_messages(messages) - response = await self.client.agenerate([messages]) + response = await self.acreate(messages, stream=False) usage = self.get_usage(response) self._update_costs(usage) return response @@ -68,7 +71,7 @@ class SparkLLM(BaseLLM): return await self._achat_completion(messages, timeout) async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: - response = self.client.astream(messages) + response = await self.acreate(messages, stream=True) collected_content = [] usage = {} async for chunk in response: @@ -82,5 +85,11 @@ class SparkLLM(BaseLLM): full_content = "".join(collected_content) return full_content - def _extract_assistant_rsp(self, context): + def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str: return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)]) + + async def acreate(self, messages: list[dict], stream: bool = True): + if stream: + return self.client.astream(messages) + else: + return await self.client.agenerate([messages]) diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 043d98d13..08d307153 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -1,3 +1,34 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : the unittest of spark api +""" +用于讯飞星火SDK的测试用例 +文档:https://www.xfyun.cn/doc/spark/Web.html +""" + + +import pytest + +from metagpt.provider.spark_api import SparkLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) + +resp_cont = resp_cont_tmpl.format(name="Spark") + + +def mock_spark_acreate(self, messages, stream): + pass + + +@pytest.mark.asyncio +async def test_spark_acompletion(mocker): + mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate) + + spark_llm = SparkLLM(mock_llm_config_spark) + + resp = await spark_llm.acompletion([messages]) + assert resp == resp_cont + + await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)