mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add tests
This commit is contained in:
parent
f304fb3871
commit
7d33755749
2 changed files with 26 additions and 5 deletions
|
|
@ -61,7 +61,6 @@ class SparkLLM(BaseLLM):
|
|||
return {}
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
messages = convert_to_messages(messages)
|
||||
response = await self.acreate(messages, stream=False)
|
||||
usage = self.get_usage(response)
|
||||
self._update_costs(usage)
|
||||
|
|
@ -89,6 +88,7 @@ class SparkLLM(BaseLLM):
|
|||
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
|
||||
|
||||
async def acreate(self, messages: list[dict], stream: bool = True):
|
||||
messages = convert_to_messages(messages)
|
||||
if stream:
|
||||
return self.client.astream(messages)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,12 @@
|
|||
"""
|
||||
|
||||
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
|
||||
from sparkai.core.outputs.chat_generation import ChatGeneration
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
|
||||
|
|
@ -16,10 +21,26 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
USAGE = {
|
||||
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
|
||||
}
|
||||
spark_agenerate_result = LLMResult(
|
||||
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
|
||||
)
|
||||
|
||||
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
|
||||
|
||||
|
||||
def mock_spark_acreate(self, messages, stream):
|
||||
pass
|
||||
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def mock_spark_acreate(self, messages, stream):
|
||||
if stream:
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return spark_agenerate_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -29,6 +50,6 @@ async def test_spark_acompletion(mocker):
|
|||
spark_llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await spark_llm.acompletion([messages])
|
||||
assert resp == resp_cont
|
||||
assert spark_llm.get_choice_text(resp) == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue