fix bugs for test

commit a05eed2e9a
Author: sunjiashuo
Date: 2025-06-17 10:52:39 +08:00
3 changed files with 143 additions and 12 deletions


@@ -191,7 +191,7 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
 BEDROCK_PROVIDER_REQUEST_BODY = {
     "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
     "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
-    "ai21": {
+    "ai21-j2": {
         "prompt": "",
         "temperature": 0.0,
         "topP": 0.0,
@@ -201,6 +201,16 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
         "presencePenalty": {"scale": 0.0},
         "frequencyPenalty": {"scale": 0.0},
     },
+    "ai21-jamba": {
+        "messages": [],
+        "temperature": 0.0,
+        "topP": 0.0,
+        "max_tokens": 0,
+        "stopSequences": [],
+        "countPenalty": {"scale": 0.0},
+        "presencePenalty": {"scale": 0.0},
+        "frequencyPenalty": {"scale": 0.0},
+    },
     "cohere": {
         "prompt": "",
         "temperature": 0.0,
@@ -214,6 +224,20 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
         "logit_bias": {},
         "truncate": "NONE",
     },
+    "cohere-command-r": {
+        "message": [],
+        "chat_history": [],
+        "temperature": 0.0,
+        "p": 0.0,
+        "k": 0.0,
+        "max_tokens": 0,
+        "stop_sequences": [],
+        "return_likelihoods": "NONE",
+        "stream": False,
+        "num_generations": 0,
+        "logit_bias": {},
+        "truncate": "NONE",
+    },
     "anthropic": {
         "anthropic_version": "bedrock-2023-05-31",
         "max_tokens": 0,
@@ -233,12 +257,20 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
 BEDROCK_PROVIDER_RESPONSE_BODY = {
     "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
     "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
-    "ai21": {
+    "ai21-jamba": {
         "id": "",
         "prompt": {"text": "Hello World", "tokens": []},
-        "completions": [
-            {"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
-        ],
+        "choices": [{"message": {"content": "Hello World"}}],
     },
+    "ai21-jamba-stream": {
+        "id": "",
+        "prompt": {"text": "Hello World", "tokens": []},
+        "choices": [{"delta": {"content": "Hello World"}}],
+    },
+    "ai21-j2": {
+        "id": "",
+        "prompt": {"text": "Hello World", "tokens": []},
+        "completions": [{"data": {"text": "Hello World"}, "finishReason": {"reason": "length", "length": 2}}],
+    },
     "cohere": {
         "generations": [
@@ -255,6 +287,21 @@ BEDROCK_PROVIDER_RESPONSE_BODY = {
         "id": "",
         "prompt": "",
     },
+    "cohere-command-r": {
+        "generations": [
+            {
+                "finish_reason": "",
+                "id": "",
+                "text": "Hello World",
+                "likelihood": 0.0,
+                "token_likelihoods": [{"token": 0.0}],
+                "is_finished": True,
+                "index": 0,
+            }
+        ],
+        "id": "",
+        "prompt": "",
+    },
     "anthropic": {
         "id": "",
         "model": "",


@@ -8,6 +8,7 @@
 import pytest

+from metagpt.configs.compress_msg_config import CompressType
 from metagpt.configs.llm_config import LLMConfig
 from metagpt.const import IMAGES
 from metagpt.provider.base_llm import BaseLLM
@@ -25,7 +26,6 @@ name = "GPT"
 class MockBaseLLM(BaseLLM):
     def __init__(self, config: LLMConfig = None):
         self.config = config or mock_llm_config
-        self.model = mock_llm_config.model

     def completion(self, messages: list[dict], timeout=3):
         return get_part_chat_completion(name)
@@ -108,6 +108,64 @@ async def test_async_base_llm():
     # assert resp == default_resp_cont


+@pytest.mark.parametrize("compress_type", list(CompressType))
+def test_compress_messages_no_effect(compress_type):
+    base_llm = MockBaseLLM()
+    messages = [
+        {"role": "system", "content": "first system msg"},
+        {"role": "system", "content": "second system msg"},
+    ]
+    for i in range(5):
+        messages.append({"role": "user", "content": f"u{i}"})
+        messages.append({"role": "assistant", "content": f"a{i}"})
+    compressed = base_llm.compress_messages(messages, compress_type=compress_type)
+    # should take no effect for short context
+    assert compressed == messages
+
+
+@pytest.mark.parametrize("compress_type", CompressType.cut_types())
+def test_compress_messages_long(compress_type):
+    base_llm = MockBaseLLM()
+    base_llm.config.model = "test_llm"
+    max_token_limit = 100
+    messages = [
+        {"role": "system", "content": "first system msg"},
+        {"role": "system", "content": "second system msg"},
+    ]
+    for i in range(100):
+        messages.append({"role": "user", "content": f"u{i}" * 10})  # ~2x10x0.5 = 10 tokens
+        messages.append({"role": "assistant", "content": f"a{i}" * 10})
+    compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)
+    print(compressed)
+    print(len(compressed))
+    assert 3 <= len(compressed) < len(messages)
+    assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system"
+    assert compressed[2]["role"] != "system"
+
+
+def test_long_messages_no_compress():
+    base_llm = MockBaseLLM()
+    messages = [{"role": "user", "content": "1" * 10000}] * 10000
+    compressed = base_llm.compress_messages(messages)
+    assert len(compressed) == len(messages)
+
+
+@pytest.mark.parametrize("compress_type", CompressType.cut_types())
+def test_compress_messages_long_no_sys_msg(compress_type):
+    base_llm = MockBaseLLM()
+    base_llm.config.model = "test_llm"
+    max_token_limit = 100
+    messages = [{"role": "user", "content": "1" * 10000}]
+    compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)
+    print(compressed)
+    assert compressed
+    assert len(compressed[0]["content"]) < len(messages[0]["content"])
+
+
 def test_format_msg(mocker):
     base_llm = MockBaseLLM()
     messages = [UserMessage(content="req"), AIMessage(content="rsp")]
@@ -143,4 +201,4 @@ def test_format_msg_w_images(mocker):

 if __name__ == "__main__":
-    pytest.main([__file__, "-s"])
+    pytest.main([__file__, "-s"])
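
The new tests above pin down the contract of BaseLLM.compress_messages for the cut-style strategies: leading system messages always survive, recent turns are preferred, and a single oversized message is truncated rather than dropped. A self-contained sketch of that behavior (an illustration of what the tests assert, not MetaGPT's actual implementation):

def cut_messages_sketch(messages: list[dict], max_token: int = 100) -> list[dict]:
    # Leading system messages are kept unconditionally.
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    # Crude estimate of ~0.5 tokens per character, matching the test's comment.
    budget = max_token - sum(len(m["content"]) // 2 for m in system)
    kept = []
    # Walk backwards so the most recent turns are kept first.
    for m in reversed(rest):
        cost = len(m["content"]) // 2
        if cost > budget:
            if not kept:
                # Truncate a single oversized message instead of dropping it,
                # mirroring test_compress_messages_long_no_sys_msg.
                kept.append({**m, "content": m["content"][: max(budget, 1) * 2]})
            break
        kept.append(m)
        budget -= cost
    return system + list(reversed(kept))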


@@ -22,9 +22,33 @@ usage = {
 }


+def get_provider_name(model: str) -> str:
+    arr = model.split(".")
+    if len(arr) == 2:
+        provider, model_name = arr  # e.g. meta, mistral
+    elif len(arr) == 3:
+        # some model_ids may contain country like us.xx.xxx
+        _, provider, model_name = arr
+    return provider
+
+
+def deal_special_provider(provider: str, model: str, stream: bool = False) -> str:
+    # special-case the ai21 and cohere command-r model families
+    if "j2-" in model:
+        provider = f"{provider}-j2"
+    elif "jamba-" in model:
+        provider = f"{provider}-jamba"
+    elif "command-r" in model:
+        provider = f"{provider}-command-r"
+    if stream and "ai21" in model:
+        provider = f"{provider}-stream"
+    return provider
+
+
 async def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
-    provider = self.config.model.split(".")[0]
-    self._update_costs(usage, self.config.model)
+    provider = get_provider_name(self.model)
+    self._update_costs(usage, self.model)
+    provider = deal_special_provider(provider, self.model)
     return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
@@ -33,7 +57,7 @@ async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
     def dict2bytes(x):
         return json.dumps(x).encode("utf-8")

-    provider = self.config.model.split(".")[0]
+    provider = get_provider_name(self.model)

     if provider == "amazon":
         response_body_bytes = dict2bytes({"outputText": "Hello World"})
@@ -44,15 +68,17 @@ async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
     elif provider == "cohere":
         response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
     else:
+        provider = deal_special_provider(provider, self.model, stream=True)
         response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])

     response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
-    self._update_costs(usage, self.config.model)
+    self._update_costs(usage, self.model)
     return response_body_stream


 def get_bedrock_request_body(model_id) -> dict:
-    provider = model_id.split(".")[0]
+    provider = get_provider_name(model_id)
+    provider = deal_special_provider(provider, model_id)
     return BEDROCK_PROVIDER_REQUEST_BODY[provider]
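
As an illustration of the routing these helpers implement (the model IDs below are real Bedrock IDs, but the assertions themselves are not part of the commit; the expected values follow from the code above):

# Plain provider prefixes, including 3-part cross-region IDs like "us.xxx.yyy".
assert get_provider_name("meta.llama3-8b-instruct-v1:0") == "meta"
assert get_provider_name("us.anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"

# Special-cased model families are mapped onto the suffixed keys added to the
# request/response dicts earlier in this commit.
assert deal_special_provider("ai21", "ai21.j2-mid-v1") == "ai21-j2"
assert deal_special_provider("ai21", "ai21.jamba-instruct-v1:0", stream=True) == "ai21-jamba-stream"
assert deal_special_provider("cohere", "cohere.command-r-plus-v1:0") == "cohere-command-r"

# Caveat: get_provider_name leaves `provider` unset for IDs that are not 2 or 3
# dot-separated parts, so such inputs would raise UnboundLocalError.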