add tests

This commit is contained in:
usamimeri_renko 2024-05-22 23:13:38 +08:00
parent f304fb3871
commit 7d33755749
2 changed files with 26 additions and 5 deletions

View file

@ -61,7 +61,6 @@ class SparkLLM(BaseLLM):
return {}
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
messages = convert_to_messages(messages)
response = await self.acreate(messages, stream=False)
usage = self.get_usage(response)
self._update_costs(usage)
@ -89,6 +88,7 @@ class SparkLLM(BaseLLM):
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
async def acreate(self, messages: list[dict], stream: bool = True):
messages = convert_to_messages(messages)
if stream:
return self.client.astream(messages)
else:

View file

@ -4,7 +4,12 @@
"""
from typing import AsyncIterator, List
import pytest
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
from sparkai.core.outputs.chat_generation import ChatGeneration
from sparkai.core.outputs.llm_result import LLMResult
from metagpt.provider.spark_api import SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
@ -16,10 +21,26 @@ from tests.metagpt.provider.req_resp_const import (
)
resp_cont = resp_cont_tmpl.format(name="Spark")
USAGE = {
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
}
spark_agenerate_result = LLMResult(
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
)
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
def mock_spark_acreate(self, messages, stream):
pass
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
for chunk in chunks:
yield chunk
async def mock_spark_acreate(self, messages, stream):
if stream:
return chunk_iterator(chunks)
else:
return spark_agenerate_result
@pytest.mark.asyncio
@ -29,6 +50,6 @@ async def test_spark_acompletion(mocker):
spark_llm = SparkLLM(mock_llm_config_spark)
resp = await spark_llm.acompletion([messages])
assert resp == resp_cont
assert spark_llm.get_choice_text(resp) == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)