add type hint

This commit is contained in:
usamimeri_renko 2024-05-22 22:43:10 +08:00
parent bb8ea2eaf9
commit f304fb3871
3 changed files with 54 additions and 14 deletions

View file

@ -1,7 +1,7 @@
# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
llm:
api_type: 'spark'
api_type: "spark"
# 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"
app_id: ""

View file

@ -2,6 +2,11 @@
# -*- coding: utf-8 -*-
from sparkai.core.messages import _convert_to_message, convert_to_messages
from sparkai.core.messages.ai import AIMessage
from sparkai.core.messages.base import BaseMessage
from sparkai.core.messages.human import HumanMessage
from sparkai.core.messages.system import SystemMessage
from sparkai.core.outputs.llm_result import LLMResult
from sparkai.llm.llm import ChatSparkLLM
from metagpt.configs.llm_config import LLMConfig, LLMType
@ -9,8 +14,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错
from metagpt.utils.common import any_to_str
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import SPARK_TOKENS
@ -38,19 +41,19 @@ class SparkLLM(BaseLLM):
streaming=True,
)
def _system_msg(self, msg: str):
def _system_msg(self, msg: str) -> SystemMessage:
return _convert_to_message(msg)
def _user_msg(self, msg: str, **kwargs):
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
return _convert_to_message(msg)
def _assistant_msg(self, msg: str):
def _assistant_msg(self, msg: str) -> AIMessage:
return _convert_to_message(msg)
def get_choice_text(self, rsp) -> str:
def get_choice_text(self, rsp: LLMResult) -> str:
return rsp.generations[0][0].text
def get_usage(self, response):
def get_usage(self, response: LLMResult):
message = response.generations[0][0].message
if hasattr(message, "additional_kwargs"):
return message.additional_kwargs.get("token_usage", {})
@ -59,7 +62,7 @@ class SparkLLM(BaseLLM):
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
messages = convert_to_messages(messages)
response = await self.client.agenerate([messages])
response = await self.acreate(messages, stream=False)
usage = self.get_usage(response)
self._update_costs(usage)
return response
@ -68,7 +71,7 @@ class SparkLLM(BaseLLM):
return await self._achat_completion(messages, timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
response = self.client.astream(messages)
response = await self.acreate(messages, stream=True)
collected_content = []
usage = {}
async for chunk in response:
@ -82,5 +85,11 @@ class SparkLLM(BaseLLM):
full_content = "".join(collected_content)
return full_content
def _extract_assistant_rsp(self, context):
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
async def acreate(self, messages: list[dict], stream: bool = True):
if stream:
return self.client.astream(messages)
else:
return await self.client.agenerate([messages])

View file

@ -1,3 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
"""
用于讯飞星火SDK的测试用例
文档https://www.xfyun.cn/doc/spark/Web.html
"""
import pytest
from metagpt.provider.spark_api import SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="Spark")
def mock_spark_acreate(self, messages, stream):
pass
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
spark_llm = SparkLLM(mock_llm_config_spark)
resp = await spark_llm.acompletion([messages])
assert resp == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)