diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 6560fc7dd..781963434 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -42,7 +42,9 @@ class BaseLLM(ABC): # OpenAI / Azure / Others aclient: Optional[Union[AsyncOpenAI]] = None cost_manager: Optional[CostManager] = None - model: Optional[str] = None # deprecated + # Maintain model name in own instance in case the global config has changed, + # Should always use model not config.model within this class + model: Optional[str] = None pricing_plan: Optional[str] = None _reasoning_content: Optional[str] = None # content from reasoning mode @@ -87,7 +89,7 @@ class BaseLLM(ABC): return {"role": "system", "content": msg} def support_image_input(self) -> bool: - return any([m in self.config.model for m in MULTI_MODAL_MODELS]) + return any([m in self.model for m in MULTI_MODAL_MODELS]) def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" @@ -321,7 +323,7 @@ class BaseLLM(ABC): def with_model(self, model: str): """Set model and return self. For example, `with_model("gpt-3.5-turbo")`.""" - self.config.model = model + self.model = model return self def get_timeout(self, timeout: int) -> int: @@ -352,7 +354,7 @@ class BaseLLM(ABC): if compress_type == CompressType.NO_COMPRESS: return messages - max_token = TOKEN_MAX.get(self.config.model, max_token) + max_token = TOKEN_MAX.get(self.model, max_token) keep_token = int(max_token * threshold) compressed = [] diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index 80da95eae..c4d9e2183 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -22,13 +22,14 @@ from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS class BedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config + self.model = config.model self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider( - self.config.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token + self.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token ) self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) - if self.config.model in NOT_SUPPORT_STREAM_MODELS: - logger.warning(f"model {self.config.model} doesn't support streaming output!") + if self.model in NOT_SUPPORT_STREAM_MODELS: + logger.warning(f"model {self.model} doesn't support streaming output!") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" @@ -72,25 +73,25 @@ class BedrockLLM(BaseLLM): async def invoke_model(self, request_body: str) -> dict: loop = asyncio.get_running_loop() response = await loop.run_in_executor( - None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body) + None, partial(self.client.invoke_model, modelId=self.model, body=request_body) ) usage = self._get_usage(response) - self._update_costs(usage, self.config.model) + self._update_costs(usage, self.model) response_body = self._get_response_body(response) return response_body async def invoke_model_with_response_stream(self, request_body: str) -> EventStream: loop = asyncio.get_running_loop() response = await loop.run_in_executor( - None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body) + None, partial(self.client.invoke_model_with_response_stream, modelId=self.model, body=request_body) ) usage = self._get_usage(response) - self._update_costs(usage, self.config.model) + self._update_costs(usage, self.model) return response @property def _const_kwargs(self) -> dict: - model_max_tokens = get_max_tokens(self.config.model) + model_max_tokens = get_max_tokens(self.model) if self.config.max_token > model_max_tokens: max_tokens = model_max_tokens else: @@ -119,7 +120,7 @@ class BedrockLLM(BaseLLM): return await self.acompletion(messages) async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: - if self.config.model in NOT_SUPPORT_STREAM_MODELS: + if self.model in NOT_SUPPORT_STREAM_MODELS: rsp = await self.acompletion(messages) full_text = self.get_choice_text(rsp) log_llm_stream(full_text) @@ -132,7 +133,7 @@ class BedrockLLM(BaseLLM): full_text = ("".join(collected_content)).lstrip() if self.__provider.usage: # if provider provide usage, update it - self._update_costs(self.__provider.usage, self.config.model) + self._update_costs(self.__provider.usage, self.model) return full_text def _get_response_body(self, response) -> dict: diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index a16a49c20..9c032d4b3 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -18,6 +18,7 @@ class HumanProvider(BaseLLM): def __init__(self, config: LLMConfig): self.config = config + self.model = config.model def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str: logger.info("It's your turn, please type in your response. You may also refer to the context below") diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index bb458b0e8..1663bf2b7 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -195,6 +195,7 @@ class OllamaLLM(BaseLLM): def __init__(self, config: LLMConfig): self.client = GeneralAPIRequestor(base_url=config.base_url, key=config.api_key) self.config = config + self.model = config.model self.http_method = "post" self.use_system_prompt = False self.cost_manager = TokenCostManager() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index c608d1d29..9ed60bfe8 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -321,6 +321,6 @@ class OpenAILLM(BaseLLM): def count_tokens(self, messages: list[dict]) -> int: try: - return count_message_tokens(messages, self.config.model) + return count_message_tokens(messages, self.model) except: return super().count_tokens(messages) diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index 3ada7908d..d0a95e734 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -37,6 +37,7 @@ class QianFanLLM(BaseLLM): self.cost_manager = CostManager(token_costs=self.token_costs) def __init_qianfan(self): + self.model = self.config.model if self.config.access_key and self.config.secret_key: # for system level auth, use access_key and secret_key, recommended by official # set environment variable due to official recommendation @@ -61,14 +62,14 @@ class QianFanLLM(BaseLLM): ("ERNIE-Speed", "ernie_speed"), ("EB-turbo-AppBuilder", "ai_apaas"), ] - if self.config.model in [pair[0] for pair in support_system_pairs]: + if self.model in [pair[0] for pair in support_system_pairs]: # only some ERNIE models support self.use_system_prompt = True if self.config.endpoint in [pair[1] for pair in support_system_pairs]: self.use_system_prompt = True - assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config" - assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config" + assert not (self.model and self.config.endpoint), "Only set `model` or `endpoint` in the config" + assert self.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config" self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS) self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS) @@ -87,8 +88,8 @@ class QianFanLLM(BaseLLM): kwargs["temperature"] = self.config.temperature if self.config.endpoint: kwargs["endpoint"] = self.config.endpoint - elif self.config.model: - kwargs["model"] = self.config.model + elif self.model: + kwargs["model"] = self.model if self.use_system_prompt: # if the model support system prompt, extract and pass it @@ -99,7 +100,7 @@ class QianFanLLM(BaseLLM): def _update_costs(self, usage: dict): """update each request's token cost""" - model_or_endpoint = self.config.model or self.config.endpoint + model_or_endpoint = self.model or self.config.endpoint local_calc_usage = model_or_endpoint in self.token_costs super()._update_costs(usage, model_or_endpoint, local_calc_usage) diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index f563dccad..60b443fcd 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -15,6 +15,7 @@ mock_llm_config = LLMConfig( app_id="mock_app_id", api_secret="mock_api_secret", domain="mock_domain", + model="mock_model", ) diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 62083a769..3ef1b4a1f 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -8,7 +8,6 @@ import pytest -from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig from metagpt.const import IMAGES from metagpt.provider.base_llm import BaseLLM @@ -26,6 +25,7 @@ name = "GPT" class MockBaseLLM(BaseLLM): def __init__(self, config: LLMConfig = None): self.config = config or mock_llm_config + self.model = mock_llm_config.model def completion(self, messages: list[dict], timeout=3): return get_part_chat_completion(name) @@ -108,64 +108,6 @@ async def test_async_base_llm(): # assert resp == default_resp_cont -@pytest.mark.parametrize("compress_type", list(CompressType)) -def test_compress_messages_no_effect(compress_type): - base_llm = MockBaseLLM() - messages = [ - {"role": "system", "content": "first system msg"}, - {"role": "system", "content": "second system msg"}, - ] - for i in range(5): - messages.append({"role": "user", "content": f"u{i}"}) - messages.append({"role": "assistant", "content": f"a{i}"}) - compressed = base_llm.compress_messages(messages, compress_type=compress_type) - # should take no effect for short context - assert compressed == messages - - -@pytest.mark.parametrize("compress_type", CompressType.cut_types()) -def test_compress_messages_long(compress_type): - base_llm = MockBaseLLM() - base_llm.config.model = "test_llm" - max_token_limit = 100 - - messages = [ - {"role": "system", "content": "first system msg"}, - {"role": "system", "content": "second system msg"}, - ] - for i in range(100): - messages.append({"role": "user", "content": f"u{i}" * 10}) # ~2x10x0.5 = 10 tokens - messages.append({"role": "assistant", "content": f"a{i}" * 10}) - compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit) - - print(compressed) - print(len(compressed)) - assert 3 <= len(compressed) < len(messages) - assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system" - assert compressed[2]["role"] != "system" - - -def test_long_messages_no_compress(): - base_llm = MockBaseLLM() - messages = [{"role": "user", "content": "1" * 10000}] * 10000 - compressed = base_llm.compress_messages(messages) - assert len(compressed) == len(messages) - - -@pytest.mark.parametrize("compress_type", CompressType.cut_types()) -def test_compress_messages_long_no_sys_msg(compress_type): - base_llm = MockBaseLLM() - base_llm.config.model = "test_llm" - max_token_limit = 100 - - messages = [{"role": "user", "content": "1" * 10000}] - compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit) - - print(compressed) - assert compressed - assert len(compressed[0]["content"]) < len(messages[0]["content"]) - - def test_format_msg(mocker): base_llm = MockBaseLLM() messages = [UserMessage(content="req"), AIMessage(content="rsp")] @@ -175,7 +117,7 @@ def test_format_msg(mocker): def test_format_msg_w_images(mocker): base_llm = MockBaseLLM() - base_llm.config.model = "gpt-4o" + base_llm.model = "gpt-4o" msg_w_images = UserMessage(content="req1") msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"]) msg_w_empty_images = UserMessage(content="req2") diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 28d1d7008..44b1189bf 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -22,13 +22,13 @@ usage = { } -def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: +async def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: provider = self.config.model.split(".")[0] self._update_costs(usage, self.config.model) return BEDROCK_PROVIDER_RESPONSE_BODY[provider] -def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: +async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: # use json object to mock EventStream def dict2bytes(x): return json.dumps(x).encode("utf-8") diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index d292a8286..6b8dd403e 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -169,7 +169,7 @@ async def test_openai_acompletion(mocker): def test_count_tokens(): llm = LLM() - llm.config.model = "gpt-4o" + llm.model = "gpt-4o" messages = [ llm._system_msg("some system msg"), llm._system_msg("some system message 2"), @@ -184,7 +184,7 @@ def test_count_tokens(): def test_count_tokens_long(): llm = LLM() - llm.config.model = "gpt-4-0613" + llm.model = "gpt-4-0613" test_msg_content = " ".join([str(i) for i in range(100000)]) messages = [ llm._system_msg("You are a helpful assistant"), @@ -193,7 +193,7 @@ def test_count_tokens_long(): cnt = llm.count_tokens(messages) # 299023, ~300k assert 290000 <= cnt <= 300000 - llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens + llm.model = "test_llm" # a non-openai model, will use heuristics base count_tokens cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference assert 290000 <= cnt <= 300000 @@ -202,7 +202,7 @@ def test_count_tokens_long(): @pytest.mark.asyncio async def test_aask_long(): llm = LLM() - llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + llm.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens messages = [ @@ -216,7 +216,7 @@ async def test_aask_long(): @pytest.mark.asyncio async def test_aask_long_no_compress(): llm = LLM() - llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + llm.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k # Not specifying llm.config.compress_type will use default "", no compress test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens messages = [