mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 17:26:22 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
eb3c6d14f9
41 changed files with 1093 additions and 363 deletions
|
|
@ -6,6 +6,9 @@
|
|||
@File : test_faiss_store.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
|
|
@ -14,8 +17,17 @@ from metagpt.logs import logger
|
|||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
num = len(texts)
|
||||
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
|
||||
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
|
||||
return embeds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_json():
|
||||
async def test_search_json(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
|
|
@ -24,7 +36,9 @@ async def test_search_json():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx():
|
||||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
|
|
@ -33,7 +47,9 @@ async def test_search_xlsx():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
async def test_write(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
|
|
|
|||
33
tests/metagpt/memory/mock_text_embed.py
Normal file
33
tests/metagpt/memory/mock_text_embed.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
dim = 1536 # openai embedding dim
|
||||
|
||||
text_embed_arr = [
|
||||
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
|
||||
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
|
||||
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
|
||||
{
|
||||
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
|
||||
"embed": np.zeros(shape=[1, dim]),
|
||||
},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
|
||||
{
|
||||
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
|
||||
"embed": np.ones(shape=[1, dim]),
|
||||
},
|
||||
]
|
||||
|
||||
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
idx = text_idx_dict.get(texts[0])
|
||||
embed = text_embed_arr[idx].get("embed")
|
||||
return embed
|
||||
|
|
@ -4,20 +4,22 @@
|
|||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config2 import config
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
def test_ltm_search(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
|
|
@ -27,20 +29,20 @@ def test_ltm_search():
|
|||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = "Write a cli snake game"
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
|
@ -56,7 +58,7 @@ def test_ltm_search():
|
|||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = "Write a Battle City"
|
||||
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
|
|
|||
|
|
@ -4,23 +4,25 @@
|
|||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = "Write a cli snake game"
|
||||
def test_idea_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
|
|
@ -33,12 +35,12 @@ def test_idea_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
|
@ -47,13 +49,17 @@ def test_idea_message():
|
|||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message():
|
||||
def test_actionout_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
content = text_embed_arr[4].get(
|
||||
"text", "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
)
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
|
@ -67,12 +73,14 @@ def test_actionout_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
|
|
|||
|
|
@ -42,3 +42,17 @@ mock_llm_config_zhipu = LLMConfig(
|
|||
model="mock_zhipu_model",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_spark = LLMConfig(
|
||||
api_type="spark",
|
||||
app_id="xxx",
|
||||
api_key="xxx",
|
||||
api_secret="xxx",
|
||||
domain="generalv2",
|
||||
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
|
||||
)
|
||||
|
||||
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
|
||||
|
||||
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")
|
||||
|
|
|
|||
145
tests/metagpt/provider/req_resp_const.py
Normal file
145
tests/metagpt/provider/req_resp_const.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : default request & response data for provider unittest
|
||||
|
||||
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
DashScopeAPIResponse,
|
||||
GenerationOutput,
|
||||
GenerationResponse,
|
||||
GenerationUsage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from qianfan.resources.typing import QfResponse
|
||||
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
prompt = "who are you?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
resp_cont_tmpl = "I'm {name}"
|
||||
default_resp_cont = resp_cont_tmpl.format(name="GPT")
|
||||
|
||||
|
||||
# part of whole ChatCompletion of openai like structure
|
||||
def get_part_chat_completion(name: str) -> dict:
|
||||
part_chat_completion = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": resp_cont_tmpl.format(name=name),
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
return part_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion(name: str) -> ChatCompletion:
|
||||
openai_chat_completion = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion.chunk",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
return openai_chat_completion_chunk
|
||||
|
||||
|
||||
# For gemini
|
||||
gemini_messages = [{"role": "user", "parts": prompt}]
|
||||
|
||||
|
||||
# For QianFan
|
||||
qf_jsonbody_dict = {
|
||||
"id": "as-4v1h587fyv",
|
||||
"object": "chat.completion",
|
||||
"created": 1695021339,
|
||||
"result": "",
|
||||
"is_truncated": False,
|
||||
"need_clear_history": False,
|
||||
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
|
||||
}
|
||||
|
||||
|
||||
def get_qianfan_response(name: str) -> QfResponse:
|
||||
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
|
||||
return QfResponse(code=200, body=qf_jsonbody_dict)
|
||||
|
||||
|
||||
# For DashScope
|
||||
def get_dashscope_response(name: str) -> GenerationResponse:
|
||||
return GenerationResponse.from_api_response(
|
||||
DashScopeAPIResponse(
|
||||
status_code=200,
|
||||
output=GenerationOutput(
|
||||
**{
|
||||
"text": "",
|
||||
"finish_reason": "",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# For llm general chat functions call
|
||||
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
|
||||
resp = await llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
|
@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion
|
|||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
resp_cont = resp_cont_tmpl.format(name="Claude")
|
||||
|
||||
|
||||
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp == Claude2(mock_llm_config).ask(prompt)
|
||||
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp == await Claude2(mock_llm_config).aask(prompt)
|
||||
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)
|
||||
|
|
|
|||
|
|
@ -11,21 +11,13 @@ import pytest
|
|||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
get_part_chat_completion,
|
||||
prompt,
|
||||
)
|
||||
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
name = "GPT"
|
||||
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
|
|
@ -33,16 +25,13 @@ class MockBaseLLM(BaseLLM):
|
|||
pass
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
return default_resp_cont
|
||||
|
||||
|
||||
def test_base_llm():
|
||||
|
|
@ -86,25 +75,25 @@ def test_base_llm():
|
|||
choice_text = base_llm.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
# resp = base_llm.ask(prompt_msg)
|
||||
# assert resp == resp_content
|
||||
# resp = base_llm.ask(prompt)
|
||||
# assert resp == default_resp_cont
|
||||
|
||||
# resp = base_llm.ask_batch([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
# resp = base_llm.ask_batch([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
|
||||
# resp = base_llm.ask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
# resp = base_llm.ask_code([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_llm():
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
resp = await base_llm.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
resp = await base_llm.aask(prompt)
|
||||
assert resp == default_resp_cont
|
||||
|
||||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
resp = await base_llm.aask_batch([prompt])
|
||||
assert resp == default_resp_cont
|
||||
|
||||
# resp = await base_llm.aask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
# resp = await base_llm.aask_code([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
|
|
|
|||
73
tests/metagpt/provider/test_dashscope_api.py
Normal file
73
tests/metagpt/provider/test_dashscope_api.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of DashScopeLLM
|
||||
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
import pytest
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_dashscope_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "qwen-max"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
@classmethod
|
||||
def mock_dashscope_call(
|
||||
cls,
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
api_key: str,
|
||||
result_format: str,
|
||||
incremental_output: bool = True,
|
||||
stream: bool = False,
|
||||
) -> GenerationResponse:
|
||||
return get_dashscope_response(name)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def mock_dashscope_acall(
|
||||
cls,
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
api_key: str,
|
||||
result_format: str,
|
||||
incremental_output: bool = True,
|
||||
stream: bool = False,
|
||||
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
|
||||
resps = [get_dashscope_response(name)]
|
||||
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[GenerationResponse]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashscope_acompletion(mocker):
|
||||
mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call)
|
||||
mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall)
|
||||
|
||||
dashscope_llm = DashScopeLLM(mock_llm_config_dashscope)
|
||||
|
||||
resp = dashscope_llm.completion(messages)
|
||||
assert resp.choices[0]["message"]["content"] == resp_cont
|
||||
|
||||
resp = await dashscope_llm.acompletion(messages)
|
||||
assert resp.choices[0]["message"]["content"] == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(dashscope_llm, prompt, messages, resp_cont)
|
||||
|
|
@ -3,14 +3,7 @@
|
|||
# @Desc : the unittest of fireworks api
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import (
|
||||
|
|
@ -20,42 +13,19 @@ from metagpt.provider.fireworks_api import (
|
|||
)
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
get_openai_chat_completion_chunk,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=dict(default_resp.usage),
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
name = "fireworks"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
|
||||
|
||||
|
||||
def test_fireworks_costmanager():
|
||||
|
|
@ -88,27 +58,17 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
fireworks_gpt = FireworksLLM(mock_llm_config)
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
fireworks_llm = FireworksLLM(mock_llm_config)
|
||||
fireworks_llm.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_gpt._update_costs(
|
||||
fireworks_llm._update_costs(
|
||||
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
|
||||
)
|
||||
assert fireworks_gpt.get_costs() == Costs(
|
||||
assert fireworks_llm.get_costs() == Costs(
|
||||
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
|
||||
)
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
resp = await fireworks_llm.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_cont
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,12 @@ from google.generativeai.types import content_types
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
gemini_messages,
|
||||
llm_general_chat_funcs_test,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -18,10 +24,8 @@ class MockGeminiResponse(ABC):
|
|||
text: str
|
||||
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
resp_cont = resp_cont_tmpl.format(name="gemini")
|
||||
default_resp = MockGeminiResponse(text=resp_cont)
|
||||
|
||||
|
||||
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
|
|
@ -60,28 +64,18 @@ async def test_gemini_acompletion(mocker):
|
|||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM(mock_llm_config)
|
||||
gemini_llm = GeminiLLM(mock_llm_config)
|
||||
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
|
||||
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
|
||||
|
||||
usage = gemini_gpt.get_usage(messages, resp_content)
|
||||
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
|
||||
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
|
||||
|
||||
resp = gemini_gpt.completion(messages)
|
||||
resp = gemini_llm.completion(gemini_messages)
|
||||
assert resp == default_resp
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
resp = await gemini_llm.acompletion(gemini_messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,15 @@ import pytest
|
|||
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
resp_cont = resp_cont_tmpl.format(name="ollama")
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
|
||||
|
||||
|
||||
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
|
||||
|
|
@ -41,19 +44,12 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
|
|||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM(mock_llm_config)
|
||||
ollama_llm = OllamaLLM(mock_llm_config)
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
resp = await ollama_llm.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
resp = await ollama_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -3,53 +3,26 @@
|
|||
# @Desc :
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "I'm llama2"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703302755,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
get_openai_chat_completion_chunk,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
name = "llama2-7b"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name)
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
|
|
@ -68,25 +41,16 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_gpt = OpenLLM(mock_llm_config)
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
openllm_llm = OpenLLM(mock_llm_config)
|
||||
openllm_llm.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_gpt.get_costs() == Costs(
|
||||
openllm_llm.cost_manager = CostManager()
|
||||
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_llm.get_costs() == Costs(
|
||||
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
|
||||
)
|
||||
|
||||
resp = await openllm_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
resp = await openllm_llm.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_cont
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
56
tests/metagpt/provider/test_qianfan_api.py
Normal file
56
tests/metagpt/provider/test_qianfan_api.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of qianfan api
|
||||
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
import pytest
|
||||
from qianfan.resources.typing import JsonBody, QfResponse
|
||||
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_qianfan_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "ERNIE-Bot-turbo"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
|
||||
return get_qianfan_response(name=name)
|
||||
|
||||
|
||||
async def mock_qianfan_ado(
|
||||
self, messages: list[dict], model: str, stream: bool = True, system: str = None
|
||||
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
|
||||
resps = [get_qianfan_response(name=name)]
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[JsonBody]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qianfan_acompletion(mocker):
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
|
||||
|
||||
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
|
||||
|
||||
resp = qianfan_llm.completion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
resp = await qianfan_llm.acompletion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
|
||||
|
|
@ -4,12 +4,18 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
|
|
@ -23,7 +29,7 @@ class MockWebSocketApp(object):
|
|||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
|
|
@ -31,34 +37,26 @@ def test_get_msg_from_web(mocker):
|
|||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
return resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask():
|
||||
llm = SparkLLM(Config.from_home("spark.yaml").llm)
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
print(resp)
|
||||
assert resp == resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
|
|
|
|||
|
|
@ -6,22 +6,24 @@ import pytest
|
|||
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_part_chat_completion,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
name = "ChatGLM-4"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_part_chat_completion(name)
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate_stream(self, **kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -46,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
|
|||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
resp = await zhipu_llm.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_cont
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
|
||||
|
||||
|
||||
def test_zhipuai_proxy():
|
||||
|
|
|
|||
|
|
@ -211,6 +211,11 @@ value
|
|||
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
|
||||
assert output == target_output
|
||||
|
||||
raw_output = '{"key": "url "http" \\"https\\" "}'
|
||||
target_output = '{"key": "url \\"http\\" \\"https\\" "}'
|
||||
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 1 column 15 (char 14)")
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_retry_parse_json_text():
|
||||
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue