diff --git a/config/config.yaml b/config/config.yaml index ab4d49f5d..5025a4977 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -141,3 +141,4 @@ TIMEOUT: 60 # Timeout for llm invocation #REDIS_PASSWORD: "YOUR_REDIS_PASSWORD" #REDIS_DB: "YOUR_REDIS_DB_INDEX, str, 0-based" +# DISABLE_LLM_PROVIDER_CHECK: false diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 074cdee0a..c47a77bdd 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -180,9 +180,11 @@ class WebBrowseAndSummarize(Action): llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Explore the web and provide summaries of articles and webpages." browse_func: Union[Callable[[list[str]], None], None] = None - web_browser_engine: WebBrowserEngine = WebBrowserEngine( - engine=WebBrowserEngineType.CUSTOM if browse_func else None, - run_func=browse_func, + web_browser_engine: WebBrowserEngine = Field( + default_factory=lambda: WebBrowserEngine( + engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None, + run_func=WebBrowseAndSummarize.browse_func, + ) ) def __init__(self, **kwargs): diff --git a/metagpt/config.py b/metagpt/config.py index 222254ac7..1ce12216d 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -138,7 +138,9 @@ class Config(metaclass=Singleton): self.gemini_api_key = self._get("GEMINI_API_KEY") self.ollama_api_base = self._get("OLLAMA_API_BASE") self.ollama_api_model = self._get("OLLAMA_API_MODEL") - # _ = self.get_default_llm_provider_enum() + + if not self._get("DISABLE_LLM_PROVIDER_CHECK"): + _ = self.get_default_llm_provider_enum() # self.openai_base_url = self._get("OPENAI_BASE_URL") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy diff --git a/metagpt/llm.py b/metagpt/llm.py index 8763642f0..f1cb98dae 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,6 +6,8 @@ @File : llm.py """ +from typing import Optional + from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider @@ -14,6 +16,9 @@ from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI: +def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI: """get the default llm provider""" + if provider is None: + provider = CONFIG.get_default_llm_provider_enum() + return LLM_REGISTRY.get_provider(provider) diff --git a/metagpt/logs.py b/metagpt/logs.py index ab1bc4e94..fb0fdd553 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -8,6 +8,7 @@ import sys from datetime import datetime +from functools import partial from loguru import logger as _logger @@ -26,3 +27,15 @@ def define_log_level(print_level="INFO", logfile_level="DEBUG"): logger = define_log_level() + + +def log_llm_stream(msg): + _llm_stream_log(msg) + + +def set_llm_stream_logfunc(func): + global _llm_stream_log + _llm_stream_log = func + + +_llm_stream_log = partial(print, end="") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e9d3ea70d..eace329aa 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -20,7 +20,7 @@ from tenacity import ( ) from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -121,7 +121,7 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content = [] async for chunk in resp: content = chunk.text - print(content, end="") + log_llm_stream(content) collected_content.append(content) full_content = "".join(collected_content) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 7d858e769..90a50a154 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -15,7 +15,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.const import LLM_API_TIMEOUT -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.llm_provider_registry import register_provider @@ -131,7 +131,7 @@ class OllamaGPTAPI(BaseGPTAPI): if not chunk.get("done", False): content = self.get_choice_text(chunk) collected_content.append(content) - print(content, end="") + log_llm_stream(content) else: # stream finished usage = self.get_usage(chunk) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 1d2cdb591..195d2ea16 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -29,7 +29,7 @@ from tenacity import ( from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.provider.llm_provider_registry import register_provider @@ -180,7 +180,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): collected_messages = [] async for i in resp: - print(i, end="") + log_llm_stream(i) collected_messages.append(i) full_reply_content = "".join(collected_messages) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 0d5663431..8d57cd444 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -16,7 +16,7 @@ from tenacity import ( ) from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -96,7 +96,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): if event.event == ZhiPuEvent.ADD.value: content = event.data collected_content.append(content) - print(content, end="") + log_llm_stream(content) elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value: content = event.data logger.error(f"event error: {content}", end="") diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 23a7faaae..3e5f268f8 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -154,7 +154,7 @@ class Role(BaseModel): builtin_class_name: str = "" _private_attributes = { - "_llm": LLM() if not is_human else HumanProvider(), + "_llm": None, "_role_id": _role_id, "_states": [], "_actions": [],