diff --git a/config/config.yaml b/config/config.yaml index ab4d49f5d..033ab0af2 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -141,3 +141,4 @@ TIMEOUT: 60 # Timeout for llm invocation #REDIS_PASSWORD: "YOUR_REDIS_PASSWORD" #REDIS_DB: "YOUR_REDIS_DB_INDEX, str, 0-based" +# DISABLE_LLM_PROVIDER_CHECK: false diff --git a/examples/search_kb.py b/examples/search_kb.py index be15846d4..0e0e0ffd0 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -6,12 +6,20 @@ """ import asyncio -from metagpt.const import EXAMPLE_PATH +from langchain.embeddings import OpenAIEmbeddings + +from metagpt.config import CONFIG +from metagpt.const import DATA_PATH, EXAMPLE_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger from metagpt.roles import Sales +def get_store(): + embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url) + return FaissStore(DATA_PATH / "example.json", embedding=embedding) + + async def search(): store = FaissStore(EXAMPLE_PATH / "example.json") role = Sales(profile="Sales", store=store) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 074cdee0a..c47a77bdd 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -180,9 +180,11 @@ class WebBrowseAndSummarize(Action): llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Explore the web and provide summaries of articles and webpages." 
browse_func: Union[Callable[[list[str]], None], None] = None - web_browser_engine: WebBrowserEngine = WebBrowserEngine( - engine=WebBrowserEngineType.CUSTOM if browse_func else None, - run_func=browse_func, + web_browser_engine: WebBrowserEngine = Field( + default_factory=lambda: WebBrowserEngine( + engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None, + run_func=WebBrowseAndSummarize.browse_func, + ) ) def __init__(self, **kwargs): diff --git a/metagpt/config.py b/metagpt/config.py index 222254ac7..1ce12216d 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -138,7 +138,9 @@ class Config(metaclass=Singleton): self.gemini_api_key = self._get("GEMINI_API_KEY") self.ollama_api_base = self._get("OLLAMA_API_BASE") self.ollama_api_model = self._get("OLLAMA_API_MODEL") - # _ = self.get_default_llm_provider_enum() + + if not self._get("DISABLE_LLM_PROVIDER_CHECK"): + _ = self.get_default_llm_provider_enum() # self.openai_base_url = self._get("OPENAI_BASE_URL") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy diff --git a/metagpt/llm.py b/metagpt/llm.py index 8763642f0..f1cb98dae 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,6 +6,8 @@ @File : llm.py """ +from typing import Optional + from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider @@ -14,6 +16,9 @@ from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI: +def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI: """get the default llm provider""" + if provider is None: + provider = CONFIG.get_default_llm_provider_enum() + return LLM_REGISTRY.get_provider(provider) diff --git a/metagpt/logs.py b/metagpt/logs.py index ab1bc4e94..fb0fdd553 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -8,6 +8,7 
@@ import sys from datetime import datetime +from functools import partial from loguru import logger as _logger @@ -26,3 +27,15 @@ def define_log_level(print_level="INFO", logfile_level="DEBUG"): logger = define_log_level() + + +def log_llm_stream(msg): + _llm_stream_log(msg) + + +def set_llm_stream_logfunc(func): + global _llm_stream_log + _llm_stream_log = func + + +_llm_stream_log = partial(print, end="") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e9d3ea70d..eace329aa 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -20,7 +20,7 @@ from tenacity import ( ) from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -121,7 +121,7 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content = [] async for chunk in resp: content = chunk.text - print(content, end="") + log_llm_stream(content) collected_content.append(content) full_content = "".join(collected_content) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 7d858e769..90a50a154 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -15,7 +15,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.const import LLM_API_TIMEOUT -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.llm_provider_registry import register_provider @@ -131,7 +131,7 @@ class OllamaGPTAPI(BaseGPTAPI): if not chunk.get("done", False): content = self.get_choice_text(chunk) collected_content.append(content) - print(content, 
end="") + log_llm_stream(content) else: # stream finished usage = self.get_usage(chunk) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 0d5663431..8d57cd444 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -16,7 +16,7 @@ from tenacity import ( ) from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -96,7 +96,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): if event.event == ZhiPuEvent.ADD.value: content = event.data collected_content.append(content) - print(content, end="") + log_llm_stream(content) elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value: content = event.data logger.error(f"event error: {content}", end="") diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 23a7faaae..3e5f268f8 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -154,7 +154,7 @@ class Role(BaseModel): builtin_class_name: str = "" _private_attributes = { - "_llm": LLM() if not is_human else HumanProvider(), + "_llm": None, "_role_id": _role_id, "_states": [], "_actions": [], diff --git a/metagpt/roles/sk_agent.py b/metagpt/roles/sk_agent.py index 791dff5e2..6063205bd 100644 --- a/metagpt/roles/sk_agent.py +++ b/metagpt/roles/sk_agent.py @@ -7,13 +7,13 @@ @Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message distribution feature for message filtering. 
""" +from typing import Any, Type from pydantic import Field from semantic_kernel import Kernel -from semantic_kernel.orchestration.sk_function_base import SKFunctionBase from semantic_kernel.planning import SequentialPlanner from semantic_kernel.planning.action_planner.action_planner import ActionPlanner -from semantic_kernel.planning.basic_planner import BasicPlanner, Plan +from semantic_kernel.planning.basic_planner import BasicPlanner from metagpt.actions import UserRequirement from metagpt.actions.execute_task import ExecuteTask @@ -41,13 +41,13 @@ class SkAgent(Role): goal: str = "Execute task based on passed in task description" constraints: str = "" - plan: Plan = None - planner_cls: BasicPlanner = BasicPlanner - planner: BasicPlanner = Field(default_factory=BasicPlanner) + plan: Any = None + planner_cls: Any = None + planner: Any = None llm: BaseGPTAPI = Field(default_factory=LLM) kernel: Kernel = Field(default_factory=Kernel) - import_semantic_skill_from_directory: str = "" - import_skill: dict[str, SKFunctionBase] = dict() + import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None + import_skill: Type[Kernel.import_skill] = None def __init__(self, **kwargs) -> None: """Initializes the Engineer role with given attributes.""" @@ -57,8 +57,8 @@ class SkAgent(Role): self.kernel = make_sk_kernel() # how funny the interface is inconsistent - if self.planner_cls == BasicPlanner: - self.planner = self.planner_cls() + if self.planner_cls == BasicPlanner or self.planner_cls is None: + self.planner = BasicPlanner() elif self.planner_cls in [SequentialPlanner, ActionPlanner]: self.planner = self.planner_cls(self.kernel) else: @@ -78,6 +78,7 @@ class SkAgent(Role): async def _act(self) -> Message: # how funny the interface is inconsistent + result = None if isinstance(self.planner, BasicPlanner): result = await self.planner.execute_plan_async(self.plan, self.kernel) elif any(isinstance(self.planner, cls) for cls in 
[SequentialPlanner, ActionPlanner]):