mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
add type hint
This commit is contained in:
parent
bb8ea2eaf9
commit
f304fb3871
3 changed files with 54 additions and 14 deletions
|
|
@ -1,7 +1,7 @@
|
|||
# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
|
||||
|
||||
llm:
|
||||
api_type: 'spark'
|
||||
api_type: "spark"
|
||||
# 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
|
||||
base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"
|
||||
app_id: ""
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sparkai.core.messages import _convert_to_message, convert_to_messages
|
||||
from sparkai.core.messages.ai import AIMessage
|
||||
from sparkai.core.messages.base import BaseMessage
|
||||
from sparkai.core.messages.human import HumanMessage
|
||||
from sparkai.core.messages.system import SystemMessage
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
from sparkai.llm.llm import ChatSparkLLM
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
|
|
@ -9,8 +14,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT
|
|||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import SPARK_TOKENS
|
||||
|
|
@ -38,19 +41,19 @@ class SparkLLM(BaseLLM):
|
|||
streaming=True,
|
||||
)
|
||||
|
||||
def _system_msg(self, msg: str):
|
||||
def _system_msg(self, msg: str) -> SystemMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def _user_msg(self, msg: str, **kwargs):
|
||||
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def _assistant_msg(self, msg: str):
|
||||
def _assistant_msg(self, msg: str) -> AIMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def get_choice_text(self, rsp) -> str:
|
||||
def get_choice_text(self, rsp: LLMResult) -> str:
|
||||
return rsp.generations[0][0].text
|
||||
|
||||
def get_usage(self, response):
|
||||
def get_usage(self, response: LLMResult):
|
||||
message = response.generations[0][0].message
|
||||
if hasattr(message, "additional_kwargs"):
|
||||
return message.additional_kwargs.get("token_usage", {})
|
||||
|
|
@ -59,7 +62,7 @@ class SparkLLM(BaseLLM):
|
|||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
messages = convert_to_messages(messages)
|
||||
response = await self.client.agenerate([messages])
|
||||
response = await self.acreate(messages, stream=False)
|
||||
usage = self.get_usage(response)
|
||||
self._update_costs(usage)
|
||||
return response
|
||||
|
|
@ -68,7 +71,7 @@ class SparkLLM(BaseLLM):
|
|||
return await self._achat_completion(messages, timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
response = self.client.astream(messages)
|
||||
response = await self.acreate(messages, stream=True)
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in response:
|
||||
|
|
@ -82,5 +85,11 @@ class SparkLLM(BaseLLM):
|
|||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
def _extract_assistant_rsp(self, context):
|
||||
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
|
||||
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
|
||||
|
||||
async def acreate(self, messages: list[dict], stream: bool = True):
|
||||
if stream:
|
||||
return self.client.astream(messages)
|
||||
else:
|
||||
return await self.client.agenerate([messages])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,34 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
"""
|
||||
用于讯飞星火SDK的测试用例
|
||||
文档:https://www.xfyun.cn/doc/spark/Web.html
|
||||
"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
|
||||
|
||||
def mock_spark_acreate(self, messages, stream):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
|
||||
|
||||
spark_llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await spark_llm.acompletion([messages])
|
||||
assert resp == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue