add tests

This commit is contained in:
usamimeri_renko 2024-05-22 23:13:38 +08:00
parent f304fb3871
commit 7d33755749
2 changed files with 26 additions and 5 deletions

View file

@ -4,7 +4,12 @@
"""
from typing import AsyncIterator, List
import pytest
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
from sparkai.core.outputs.chat_generation import ChatGeneration
from sparkai.core.outputs.llm_result import LLMResult
from metagpt.provider.spark_api import SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
@ -16,10 +21,26 @@ from tests.metagpt.provider.req_resp_const import (
)
resp_cont = resp_cont_tmpl.format(name="Spark")
USAGE = {
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
}
spark_agenerate_result = LLMResult(
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
)
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
def mock_spark_acreate(self, messages, stream):
pass
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
for chunk in chunks:
yield chunk
async def mock_spark_acreate(self, messages, stream):
if stream:
return chunk_iterator(chunks)
else:
return spark_agenerate_result
@pytest.mark.asyncio
@ -29,6 +50,6 @@ async def test_spark_acompletion(mocker):
spark_llm = SparkLLM(mock_llm_config_spark)
resp = await spark_llm.acompletion([messages])
assert resp == resp_cont
assert spark_llm.get_choice_text(resp) == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)