Merge pull request #1291 from usamimeri/update_spark

Feat: support async spark api
This commit is contained in:
Alexander Wu 2024-05-29 16:20:29 +08:00 committed by GitHub
commit 519be03d4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 123 additions and 191 deletions

View file

@ -1,62 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
"""
用于讯飞星火SDK的测试用例
文档https://www.xfyun.cn/doc/spark/Web.html
"""
from typing import AsyncIterator, List
import pytest
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
from sparkai.core.outputs.chat_generation import ChatGeneration
from sparkai.core.outputs.llm_result import LLMResult
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from metagpt.provider.spark_api import SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="Spark")
USAGE = {
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
}
spark_agenerate_result = LLMResult(
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
)
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
class MockWebSocketApp(object):
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
pass
def run_forever(self, sslopt=None):
pass
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
for chunk in chunks:
yield chunk
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
assert ret == ""
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
assert resp == resp_cont
async def mock_spark_acreate(self, messages, stream):
if stream:
return chunk_iterator(chunks)
else:
return spark_agenerate_result
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
spark_llm = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config_spark)
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_llm.acompletion([messages])
assert spark_llm.get_choice_text(resp) == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)