From 325e45247e4682c7cc1425d96059b1a308962eb0 Mon Sep 17 00:00:00 2001 From: better629 Date: Sun, 29 Sep 2024 15:15:29 +0800 Subject: [PATCH] support o1-series --- metagpt/configs/llm_config.py | 3 +++ metagpt/provider/llm_provider_registry.py | 6 +++++- metagpt/provider/openai_api.py | 4 ++++ metagpt/utils/token_counter.py | 1 + 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index e7c280ee3..7388063aa 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -90,6 +90,9 @@ class LLMConfig(YamlModel): # Cost Control calc_usage: bool = True + # For Messages Control + use_system_prompt: bool = True + @field_validator("api_key") @classmethod def check_llm_key(cls, v): diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py index 4fd2b1978..7f8618590 100644 --- a/metagpt/provider/llm_provider_registry.py +++ b/metagpt/provider/llm_provider_registry.py @@ -37,7 +37,11 @@ def register_provider(keys): def create_llm_instance(config: LLMConfig) -> BaseLLM: """get the default llm provider""" - return LLM_REGISTRY.get_provider(config.api_type)(config) + llm = LLM_REGISTRY.get_provider(config.api_type)(config) + if llm.use_system_prompt and not config.use_system_prompt: + # for models like o1-series, default openai provider.use_system_prompt is True, but it should be False for o1-* + llm.use_system_prompt = config.use_system_prompt + return llm # Registry instance diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 1e0ccd206..ce3a06ec8 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -138,6 +138,10 @@ class OpenAILLM(BaseLLM): "model": self.model, "timeout": self.get_timeout(timeout), } + if "o1-" in self.model: + # compatible to openai o1-series + kwargs["temperature"] = 1 + kwargs.pop("max_tokens") if extra_kwargs: kwargs.update(extra_kwargs) return kwargs diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 90fd9e960..c922f2cb4 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -397,6 +397,7 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"): "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4-0125-preview", + "gpt-4-1106-preview", "gpt-4-turbo", "gpt-4-vision-preview", "gpt-4-1106-vision-preview",