From 7d337557498385aa25195f6fd291cadb7148cb15 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Wed, 22 May 2024 23:13:38 +0800 Subject: [PATCH] add tests --- metagpt/provider/spark_api.py | 2 +- tests/metagpt/provider/test_spark_api.py | 29 ++++++++++++++++++++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index bdac050d3..8a38d99c5 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -61,7 +61,6 @@ class SparkLLM(BaseLLM): return {} async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - messages = convert_to_messages(messages) response = await self.acreate(messages, stream=False) usage = self.get_usage(response) self._update_costs(usage) @@ -89,6 +88,7 @@ class SparkLLM(BaseLLM): return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)]) async def acreate(self, messages: list[dict], stream: bool = True): + messages = convert_to_messages(messages) if stream: return self.client.astream(messages) else: diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 08d307153..a73b3ff38 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,7 +4,12 @@ """ +from typing import AsyncIterator, List + import pytest +from sparkai.core.messages.ai import AIMessage, AIMessageChunk +from sparkai.core.outputs.chat_generation import ChatGeneration +from sparkai.core.outputs.llm_result import LLMResult from metagpt.provider.spark_api import SparkLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark @@ -16,10 +21,26 @@ from tests.metagpt.provider.req_resp_const import ( ) resp_cont = resp_cont_tmpl.format(name="Spark") +USAGE = { + "token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000} +} +spark_agenerate_result = LLMResult( + generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]] +) + +chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)] -def mock_spark_acreate(self, messages, stream): - pass +async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]: + for chunk in chunks: + yield chunk + + +async def mock_spark_acreate(self, messages, stream): + if stream: + return chunk_iterator(chunks) + else: + return spark_agenerate_result @pytest.mark.asyncio @@ -29,6 +50,6 @@ async def test_spark_acompletion(mocker): spark_llm = SparkLLM(mock_llm_config_spark) resp = await spark_llm.acompletion([messages]) - assert resp == resp_cont + assert spark_llm.get_choice_text(resp) == resp_cont - await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont) + await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)