diff --git a/examples/agent_creator.py b/examples/agent_creator.py index bd58840ce..34160d398 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -6,12 +6,13 @@ Author: garylin2099 import re from metagpt.actions import Action -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.const import METAGPT_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message +config = Config.default() EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index ff030ec87..64f003f91 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -14,7 +14,6 @@ from typing import Optional, Set, Tuple import aiofiles from metagpt.actions import Action -from metagpt.config2 import config from metagpt.const import ( AGGREGATION, COMPOSITION, @@ -40,7 +39,7 @@ class RebuildClassView(Action): graph_db: Optional[GraphRepository] = None - async def run(self, with_messages=None, format=config.prompt_schema): + async def run(self, with_messages=None, format=None): """ Implementation of `Action`'s `run` method. @@ -48,6 +47,7 @@ class RebuildClassView(Action): with_messages (Optional[Type]): An optional argument specifying messages to react to. format (str): The format for the prompt schema. 
""" + format = format if format else self.config.prompt_schema graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) repo_parser = RepoParser(base_directory=Path(self.i_context)) diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py index fd356d58f..e23487511 100644 --- a/metagpt/actions/rebuild_sequence_view.py +++ b/metagpt/actions/rebuild_sequence_view.py @@ -18,7 +18,6 @@ from pydantic import BaseModel from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import Action -from metagpt.config2 import config from metagpt.const import GRAPH_REPO_FILE_REPO from metagpt.logs import logger from metagpt.repo_parser import CodeBlockInfo, DotClassInfo @@ -84,7 +83,7 @@ class RebuildSequenceView(Action): graph_db: Optional[GraphRepository] = None - async def run(self, with_messages=None, format=config.prompt_schema): + async def run(self, with_messages=None, format=None): """ Implementation of `Action`'s `run` method. @@ -92,6 +91,7 @@ class RebuildSequenceView(Action): with_messages (Optional[Type]): An optional argument specifying messages to react to. format (str): The format for the prompt schema. 
""" + format = format if format else self.config.prompt_schema graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) if not self.i_context: diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 5e670520c..98edfddb0 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -8,7 +8,6 @@ from typing import Any, Callable, Coroutine, Optional, Union from pydantic import TypeAdapter, model_validator from metagpt.actions import Action -from metagpt.config2 import config from metagpt.logs import logger from metagpt.tools.search_engine import SearchEngine from metagpt.tools.web_browser_engine import WebBrowserEngine @@ -134,8 +133,8 @@ class CollectLinks(Action): if len(remove) == 0: break - model_name = config.llm.model - prompt = reduce_message_length(gen_msg(), model_name, system_text, config.llm.max_token) + model_name = self.config.llm.model + prompt = reduce_message_length(gen_msg(), model_name, system_text, self.config.llm.max_token) logger.debug(prompt) queries = await self._aask(prompt, [system_text]) try: diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py index 81f66f9a1..3fec32783 100644 --- a/metagpt/actions/talk_action.py +++ b/metagpt/actions/talk_action.py @@ -9,7 +9,6 @@ from typing import Optional from metagpt.actions import Action -from metagpt.config2 import config from metagpt.logs import logger from metagpt.schema import Message @@ -26,7 +25,7 @@ class TalkAction(Action): @property def language(self): - return self.context.kwargs.language or config.language + return self.context.kwargs.language or self.config.language @property def prompt(self): diff --git a/metagpt/config2.py b/metagpt/config2.py index 6588a6036..8ed9d3f6b 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -97,20 +97,21 @@ class Config(CLIParams, 
YamlModel): return Config.from_yaml_file(pathname) @classmethod - def default(cls): + def default(cls, reload: bool = False): """Load default config - Priority: env < default_config_paths - Inside default_config_paths, the latter one overwrites the former one """ - default_config_paths: List[Path] = [ + default_config_paths = ( METAGPT_ROOT / "config/config2.yaml", CONFIG_ROOT / "config2.yaml", - ] - - dicts = [dict(os.environ)] - dicts += [Config.read_yaml(path) for path in default_config_paths] - final = merge_dict(dicts) - return Config(**final) + ) + if reload or default_config_paths not in _CONFIG_CACHE: + dicts = [dict(os.environ)] + dicts += [Config.read_yaml(path) for path in default_config_paths] + final = merge_dict(dicts) + _CONFIG_CACHE[default_config_paths] = Config(**final) + return _CONFIG_CACHE[default_config_paths] @classmethod def from_llm_config(cls, llm_config: dict): @@ -160,4 +161,4 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict: return result -config = Config.default() +_CONFIG_CACHE = {} diff --git a/metagpt/configs/browser_config.py b/metagpt/configs/browser_config.py index 2f8024f44..fafbaeeb8 100644 --- a/metagpt/configs/browser_config.py +++ b/metagpt/configs/browser_config.py @@ -5,12 +5,23 @@ @Author : alexanderwu @File : browser_config.py """ +from enum import Enum from typing import Literal -from metagpt.tools import WebBrowserEngineType from metagpt.utils.yaml_model import YamlModel +class WebBrowserEngineType(Enum): + PLAYWRIGHT = "playwright" + SELENIUM = "selenium" + CUSTOM = "custom" + + @classmethod + def __missing__(cls, key): + """Default type conversion""" + return cls.CUSTOM + + class BrowserConfig(YamlModel): """Config for Browser""" diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py index 7b50fb6d3..2c773b685 100644 --- a/metagpt/configs/search_config.py +++ b/metagpt/configs/search_config.py @@ -5,14 +5,23 @@ @Author : alexanderwu @File : search_config.py """ +from enum import Enum from 
typing import Callable, Optional from pydantic import ConfigDict, Field -from metagpt.tools import SearchEngineType from metagpt.utils.yaml_model import YamlModel +class SearchEngineType(Enum): + SERPAPI_GOOGLE = "serpapi" + SERPER_GOOGLE = "serper" + DIRECT_GOOGLE = "google" + DUCK_DUCK_GO = "ddg" + CUSTOM_ENGINE = "custom" + BING = "bing" + + class SearchConfig(YamlModel): """Config for Search""" diff --git a/metagpt/context.py b/metagpt/context.py index 384e8da48..0769f78eb 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -10,7 +10,7 @@ from __future__ import annotations import os from typing import Any, Dict, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config from metagpt.configs.llm_config import LLMConfig, LLMType @@ -61,7 +61,7 @@ class Context(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) kwargs: AttrDict = AttrDict() - config: Config = Config.default() + config: Config = Field(default_factory=Config.default) cost_manager: CostManager = CostManager() diff --git a/metagpt/environment/minecraft/minecraft_env.py b/metagpt/environment/minecraft/minecraft_env.py index 0f39c9ccd..2bf39095c 100644 --- a/metagpt/environment/minecraft/minecraft_env.py +++ b/metagpt/environment/minecraft/minecraft_env.py @@ -11,7 +11,7 @@ from typing import Any, Iterable from llama_index.vector_stores.chroma import ChromaVectorStore from pydantic import ConfigDict, Field -from metagpt.config2 import config as CONFIG +from metagpt.config2 import Config from metagpt.environment.base_env import Environment from metagpt.environment.minecraft.const import MC_CKPT_DIR from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv @@ -82,7 +82,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv): persist_dir=f"{MC_CKPT_DIR}/skill/vectordb", ) - if CONFIG.resume: + if Config.default().resume: logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action") 
self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json") diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index 777d55ca9..888e61743 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, TypeVar from pydantic import BaseModel, ConfigDict, model_validator -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder from metagpt.exp_pool.manager import ExperienceManager, exp_manager from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge @@ -50,11 +50,14 @@ def exp_cache( """ def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: - if not config.exp_pool.enabled: - return func - @functools.wraps(func) async def get_or_create(args: Any, kwargs: Any) -> ReturnType: + config = Config.default() + + if not config.exp_pool.enabled: + rsp = func(*args, **kwargs) + return await rsp if asyncio.iscoroutine(rsp) else rsp + handler = ExpCacheHandler( func=func, args=args, diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index b6ae9c0a3..253d45508 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -2,9 +2,9 @@ from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field -from metagpt.config2 import Config, config +from metagpt.config2 import Config from metagpt.exp_pool.schema import ( DEFAULT_COLLECTION_NAME, DEFAULT_SIMILARITY_TOP_K, @@ -29,7 +29,7 @@ class ExperienceManager(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - config: Config = config + config: Config = Field(default_factory=Config.default) _storage: Any = None _vector_store: Any = None diff --git a/metagpt/ext/stanford_town/actions/st_action.py b/metagpt/ext/stanford_town/actions/st_action.py index 
321676374..48cda353c 100644 --- a/metagpt/ext/stanford_town/actions/st_action.py +++ b/metagpt/ext/stanford_town/actions/st_action.py @@ -8,7 +8,6 @@ from pathlib import Path from typing import Any, Optional, Union from metagpt.actions.action import Action -from metagpt.config2 import config from metagpt.ext.stanford_town.utils.const import PROMPTS_DIR from metagpt.logs import logger @@ -62,13 +61,13 @@ class STAction(Action): async def _run_gpt35_max_tokens(self, prompt: str, max_tokens: int = 50, retry: int = 3): for idx in range(retry): try: - tmp_max_tokens_rsp = getattr(config.llm, "max_token", 1500) - setattr(config.llm, "max_token", max_tokens) + tmp_max_tokens_rsp = getattr(self.config.llm, "max_token", 1500) + setattr(self.config.llm, "max_token", max_tokens) self.llm.use_system_prompt = False # to make it behave like a non-chat completions llm_resp = await self._aask(prompt) - setattr(config.llm, "max_token", tmp_max_tokens_rsp) + setattr(self.config.llm, "max_token", tmp_max_tokens_rsp) logger.info(f"Action: {self.cls_name} llm _run_gpt35_max_tokens raw resp: {llm_resp}") if self._func_validate(llm_resp, prompt): return self._func_cleanup(llm_resp, prompt) diff --git a/metagpt/ext/stanford_town/utils/utils.py b/metagpt/ext/stanford_town/utils/utils.py index 3aa0e80e8..e09cce8fe 100644 --- a/metagpt/ext/stanford_town/utils/utils.py +++ b/metagpt/ext/stanford_town/utils/utils.py @@ -13,7 +13,7 @@ from typing import Union from openai import OpenAI -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger @@ -48,6 +48,7 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True): def get_embedding(text, model: str = "text-embedding-ada-002"): + config = Config.default() text = text.replace("\n", " ") if not text: text = "this is blank" diff --git a/metagpt/learn/text_to_embedding.py b/metagpt/learn/text_to_embedding.py index f859ab638..2b4adda80 100644 --- a/metagpt/learn/text_to_embedding.py +++ 
b/metagpt/learn/text_to_embedding.py @@ -6,12 +6,13 @@ @File : text_to_embedding.py @Desc : Text-to-Embedding skill, which provides text-to-embedding functionality. """ -import metagpt.config2 +from typing import Optional + from metagpt.config2 import Config from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding -async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config): +async def text_to_embedding(text, model="text-embedding-ada-002", config: Optional[Config] = None): """Text to embedding :param text: The text used for embedding. @@ -19,6 +20,7 @@ async def text_to_embedding(text, model="text-embedding-ada-002", config: Config :param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys` :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. """ + config = config if config else Config.default() openai_api_key = config.get_openai_llm().api_key proxy = config.get_openai_llm().proxy return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy) diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index 163859fc0..9bfed532b 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -7,8 +7,8 @@ @Desc : Text-to-Image skill, which provides text-to-image functionality. """ import base64 +from typing import Optional -import metagpt.config2 from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT from metagpt.llm import LLM @@ -17,7 +17,7 @@ from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image from metagpt.utils.s3 import S3 -async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config): +async def text_to_image(text, size_type: str = "512x512", config: Optional[Config] = None): """Text to image :param text: The text used for image conversion. 
@@ -25,6 +25,7 @@ async def text_to_image(text, size_type: str = "512x512", config: Config = metag :param config: Config :return: The image data is returned in Base64 encoding. """ + config = config if config else Config.default() image_declaration = "data:image/png;base64," model_url = config.metagpt_tti_url diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 8dbd6d243..9d3dba685 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -6,7 +6,8 @@ @File : text_to_speech.py @Desc : Text-to-Speech skill, which provides text-to-speech functionality """ -import metagpt.config2 +from typing import Optional + from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT from metagpt.tools.azure_tts import oas3_azsure_tts @@ -20,7 +21,7 @@ async def text_to_speech( voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl", - config: Config = metagpt.config2.config, + config: Optional[Config] = None, ): """Text to speech For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` @@ -38,7 +39,7 @@ async def text_to_speech( :return: Returns the Base64-encoded .wav/.mp3 file data if successful, otherwise an empty string. 
""" - + config = config if config else Config.default() subscription_key = config.azure_tts_subscription_key region = config.azure_tts_region if subscription_key and region: diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index c58148ead..8c2846d1d 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -12,9 +12,9 @@ import json import re from typing import Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator -from metagpt.config2 import config +from metagpt.config2 import Config as _Config from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger from metagpt.provider import MetaGPTLLM @@ -32,6 +32,12 @@ class BrainMemory(BaseModel): last_talk: Optional[str] = None cacheable: bool = True llm: Optional[BaseLLM] = Field(default=None, exclude=True) + config: Optional[_Config] = None + + @field_validator("config") + @classmethod + def set_default_config(cls, config): + return config if config else _Config.default() class Config: arbitrary_types_allowed = True @@ -54,9 +60,8 @@ class BrainMemory(BaseModel): texts = [m.content for m in self.knowledge] return "\n".join(texts) - @staticmethod - async def loads(redis_key: str) -> "BrainMemory": - redis = Redis(config.redis) + async def loads(self, redis_key: str) -> "BrainMemory": + redis = Redis(self.config.redis) if not redis_key: return BrainMemory() v = await redis.get(key=redis_key) @@ -70,7 +75,7 @@ class BrainMemory(BaseModel): async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60): if not self.is_dirty: return - redis = Redis(config.redis) + redis = Redis(self.config.redis) if not redis_key: return False v = self.model_dump_json() @@ -140,7 +145,7 @@ class BrainMemory(BaseModel): return text summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit) if summary: - await 
self.set_history_summary(history_summary=summary, redis_key=config.redis_key) + await self.set_history_summary(history_summary=summary, redis_key=self.config.redis_key) return summary raise ValueError(f"text too long:{text_length}") @@ -164,7 +169,7 @@ class BrainMemory(BaseModel): msgs.reverse() self.history = msgs self.is_dirty = True - await self.dumps(redis_key=config.redis.key) + await self.dumps(redis_key=self.config.redis.key) self.is_dirty = False return BrainMemory.to_metagpt_history_format(self.history) @@ -181,7 +186,7 @@ class BrainMemory(BaseModel): summary = await self.summarize(llm=llm, max_words=500) - language = config.language + language = self.config.language command = f"Translate the above summary into a {language} title of less than {max_words} words." summaries = [summary, command] msg = "\n".join(summaries) diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 3613fd228..8a9d4bc95 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -1,7 +1,7 @@ """RAG Embedding Factory.""" from __future__ import annotations -from typing import Any +from typing import Any, Optional from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding @@ -9,7 +9,7 @@ from llama_index.embeddings.gemini import GeminiEmbedding from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.embeddings.openai import OpenAIEmbedding -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.base import GenericFactory @@ -18,7 +18,7 @@ from metagpt.rag.factories.base import GenericFactory class RAGEmbeddingFactory(GenericFactory): """Create LlamaIndex Embedding with MetaGPT's embedding config.""" - def __init__(self): + def __init__(self, config: Optional[Config] = None): creators = { 
EmbeddingType.OPENAI: self._create_openai, EmbeddingType.AZURE: self._create_azure, @@ -29,6 +29,7 @@ class RAGEmbeddingFactory(GenericFactory): LLMType.AZURE: self._create_azure, } super().__init__(creators) + self.config = config if config else Config.default() def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: """Key is EmbeddingType.""" @@ -40,18 +41,18 @@ If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. Raise TypeError if embedding type not found. """ - if config.embedding.api_type: - return config.embedding.api_type + if self.config.embedding.api_type: + return self.config.embedding.api_type - if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: - return config.llm.api_type + if self.config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: + return self.config.llm.api_type raise TypeError("To use RAG, please set your embedding in config2.yaml.") def _create_openai(self) -> OpenAIEmbedding: params = dict( - api_key=config.embedding.api_key or config.llm.api_key, - api_base=config.embedding.base_url or config.llm.base_url, + api_key=self.config.embedding.api_key or self.config.llm.api_key, + api_base=self.config.embedding.base_url or self.config.llm.base_url, ) self._try_set_model_and_batch_size(params) @@ -60,9 +61,9 @@ def _create_azure(self) -> AzureOpenAIEmbedding: params = dict( - api_key=config.embedding.api_key or config.llm.api_key, - azure_endpoint=config.embedding.base_url or config.llm.base_url, - api_version=config.embedding.api_version or config.llm.api_version, + api_key=self.config.embedding.api_key or self.config.llm.api_key, + azure_endpoint=self.config.embedding.base_url or self.config.llm.base_url, + api_version=self.config.embedding.api_version or self.config.llm.api_version, ) self._try_set_model_and_batch_size(params) @@ -71,8 +72,8 @@ class 
RAGEmbeddingFactory(GenericFactory): def _create_gemini(self) -> GeminiEmbedding: params = dict( - api_key=config.embedding.api_key, - api_base=config.embedding.base_url, + api_key=self.config.embedding.api_key, + api_base=self.config.embedding.base_url, ) self._try_set_model_and_batch_size(params) @@ -81,7 +82,7 @@ class RAGEmbeddingFactory(GenericFactory): def _create_ollama(self) -> OllamaEmbedding: params = dict( - base_url=config.embedding.base_url, + base_url=self.config.embedding.base_url, ) self._try_set_model_and_batch_size(params) @@ -90,14 +91,15 @@ class RAGEmbeddingFactory(GenericFactory): def _try_set_model_and_batch_size(self, params: dict): """Set the model_name and embed_batch_size only when they are specified.""" - if config.embedding.model: - params["model_name"] = config.embedding.model + if self.config.embedding.model: + params["model_name"] = self.config.embedding.model - if config.embedding.embed_batch_size: - params["embed_batch_size"] = config.embedding.embed_batch_size + if self.config.embedding.embed_batch_size: + params["embed_batch_size"] = self.config.embedding.embed_batch_size def _raise_for_key(self, key: Any): raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") -get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding +def get_rag_embedding(key: EmbeddingType = None, config: Optional[Config] = None): + return RAGEmbeddingFactory(config=config).get_rag_embedding(key) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 9fd19cab5..5d27cde3a 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -10,9 +10,9 @@ from llama_index.core.llms import ( LLMMetadata, ) from llama_index.core.llms.callbacks import llm_completion_callback -from pydantic import Field +from pydantic import Field, model_validator -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from 
metagpt.utils.async_helper import NestAsyncio @@ -26,9 +26,23 @@ class RAGLLM(CustomLLM): """ model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.") - context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) - num_output: int = config.llm.max_token - model_name: str = config.llm.model + context_window: int = -1 + num_output: int = -1 + model_name: str = "" + + @model_validator(mode="after") + def update_from_config(self): + config = Config.default() + if self.context_window < 0: + self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) + + if self.num_output < 0: + self.num_output = config.llm.max_token + + if not self.model_name: + self.model_name = config.llm.model + + return self @property def metadata(self) -> LLMMetadata: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index a8a10f90e..5e97e60c3 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -10,7 +10,7 @@ from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType from metagpt.logs import logger from metagpt.rag.interface import RAGObject @@ -45,6 +45,7 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: + config = Config.default() self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( config.embedding.api_type, 1536 ) diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 2ea16f55f..f74b61191 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -27,7 +27,7 @@ def generate_repo( recover_path=None, ): """Run the startup logic. 
Can be called from CLI or other Python scripts.""" - from metagpt.config2 import config + from metagpt.config2 import Config from metagpt.context import Context from metagpt.roles import ( Architect, @@ -38,6 +38,8 @@ def generate_repo( ) from metagpt.team import Team + config = Config.default() + config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) ctx = Context(config=config) diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index 35fa04658..2027dbb1d 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -6,33 +6,18 @@ @File : __init__.py """ -from enum import Enum from metagpt.tools import libs # this registers all tools from metagpt.tools.tool_registry import TOOL_REGISTRY +from metagpt.configs.search_config import SearchEngineType +from metagpt.configs.browser_config import WebBrowserEngineType + _ = libs, TOOL_REGISTRY # Avoid pre-commit error -class SearchEngineType(Enum): - SERPAPI_GOOGLE = "serpapi" - SERPER_GOOGLE = "serper" - DIRECT_GOOGLE = "google" - DUCK_DUCK_GO = "ddg" - CUSTOM_ENGINE = "custom" - BING = "bing" - - -class WebBrowserEngineType(Enum): - PLAYWRIGHT = "playwright" - SELENIUM = "selenium" - CUSTOM = "custom" - - @classmethod - def __missing__(cls, key): - """Default type conversion""" - return cls.CUSTOM - - class SearchInterface: async def asearch(self, *args, **kwargs): ... 
+ + +__all__ = ["SearchEngineType", "WebBrowserEngineType", "TOOL_REGISTRY"] diff --git a/metagpt/tools/libs/gpt_v_generator.py b/metagpt/tools/libs/gpt_v_generator.py index baedc3d61..66c023766 100644 --- a/metagpt/tools/libs/gpt_v_generator.py +++ b/metagpt/tools/libs/gpt_v_generator.py @@ -7,7 +7,9 @@ """ import re from pathlib import Path +from typing import Optional +from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger from metagpt.tools.tool_registry import register_tool @@ -36,11 +38,11 @@ class GPTvGenerator: It utilizes a vision model to analyze the layout from an image and generate webpage codes accordingly. """ - def __init__(self): + def __init__(self, config: Optional[Config]): """Initialize GPTvGenerator class with default values from the configuration.""" - from metagpt.config2 import config from metagpt.llm import LLM + config = config if config else Config.default() self.llm = LLM(llm_config=config.get_openai_llm()) self.llm.model = "gpt-4-vision-preview" diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 243871aff..9e67a3585 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.provider.openai_api import OpenAILLM as GPTAPI from metagpt.utils.common import awrite @@ -282,6 +282,7 @@ class UTGenerator: """Choose based on different calling methods""" result = "" if self.chatgpt_method == "API": + config = Config.default() result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages) return result diff --git a/metagpt/utils/embedding.py b/metagpt/utils/embedding.py index 3d53a314c..3fcf1f25b 100644 --- a/metagpt/utils/embedding.py +++ b/metagpt/utils/embedding.py @@ -7,10 +7,11 @@ """ from llama_index.embeddings.openai import OpenAIEmbedding -from metagpt.config2 import config +from 
metagpt.config2 import Config def get_embedding() -> OpenAIEmbedding: + config = Config.default() llm = config.get_openai_llm() if llm is None: raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.") diff --git a/metagpt/utils/make_sk_kernel.py b/metagpt/utils/make_sk_kernel.py index 283a682d6..f0c55b07c 100644 --- a/metagpt/utils/make_sk_kernel.py +++ b/metagpt/utils/make_sk_kernel.py @@ -13,10 +13,11 @@ from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion impo OpenAIChatCompletion, ) -from metagpt.config2 import config +from metagpt.config2 import Config def make_sk_kernel(): + config = Config.default() kernel = sk.Kernel() if llm := config.get_azure_llm(): kernel.add_chat_service( diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py index ba33b8d61..d87ae4f83 100644 --- a/metagpt/utils/mermaid.py +++ b/metagpt/utils/mermaid.py @@ -9,12 +9,14 @@ import asyncio import os from pathlib import Path -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger from metagpt.utils.common import awrite, check_cmd_exists -async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: +async def mermaid_to_file( + engine, mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None +) -> int: """suffix: png/svg/pdf :param mermaid_code: mermaid code @@ -24,6 +26,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt :return: 0 if succeed, -1 if failed """ # Write the Mermaid code to a temporary file + config = config if config else Config.default() dir_name = os.path.dirname(output_file_without_suffix) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name) diff --git a/metagpt/utils/mmdc_pyppeteer.py b/metagpt/utils/mmdc_pyppeteer.py index f029325f1..4e30ee538 100644 --- a/metagpt/utils/mmdc_pyppeteer.py +++ 
b/metagpt/utils/mmdc_pyppeteer.py @@ -10,11 +10,11 @@ from urllib.parse import urljoin from pyppeteer import launch -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: +async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None) -> int: """ Converts the given Mermaid code to various output formats and saves them to files. @@ -27,6 +27,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, Returns: int: Returns 1 if the conversion and saving were successful, -1 otherwise. """ + config = config if config else Config.default() suffixes = ["png", "svg", "pdf"] __dirname = os.path.dirname(os.path.abspath(__file__)) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 68fa73108..5c57693f7 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -4,12 +4,12 @@ import copy from enum import Enum -from typing import Callable, Union +from typing import Callable, Optional, Union import regex as re from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger from metagpt.utils.custom_decoder import CustomDecoder @@ -154,7 +154,9 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = return output -def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str: +def repair_llm_raw_output( + output: str, req_keys: list[str], repair_type: RepairType = None, config: Optional[Config] = None +) -> str: """ in open-source llm model, it usually can't follow the instruction well, the output may be incomplete, so here we try to repair it and use all repair methods by default. 
@@ -169,6 +171,7 @@ def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairT target: { xxx } output: { xxx }] """ + config = config if config else Config.default() if not config.repair_llm_output: return output @@ -256,6 +259,7 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R "next_action":"None" } """ + config = Config.default() if retry_state.outcome.failed: if retry_state.args: # # can't be used as args=retry_state.args @@ -276,8 +280,12 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R return run_and_passon +def repair_stop_after_attempt(retry_state): + return stop_after_attempt(3 if Config.default().repair_llm_output else 0)(retry_state) + + @retry( - stop=stop_after_attempt(3 if config.repair_llm_output else 0), + stop=repair_stop_after_attempt, wait=wait_fixed(1), after=run_after_exp_and_passon_next_retry(logger), ) diff --git a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py index e2aa3d17f..207521c97 100644 --- a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py +++ b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py @@ -2,13 +2,14 @@ import asyncio import json from datetime import datetime -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT from metagpt.logs import logger from metagpt.roles.di.swe_agent import SWEAgent from metagpt.tools.libs.terminal import Terminal from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset +config = Config.default() # Specify by yourself TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo" DATA_DIR = METAGPT_ROOT / "data/hugging_face" diff --git a/tests/metagpt/test_document.py b/tests/metagpt/test_document.py index 9c076f4e6..29393bb13 100644 --- a/tests/metagpt/test_document.py +++ b/tests/metagpt/test_document.py @@ -5,10 +5,12 @@ @Author : alexanderwu @File 
: test_document.py """ -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.document import Repo from metagpt.logs import logger +config = Config.default() + def set_existing_repo(path): repo1 = Repo.from_path(path) diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py index f72b5663b..ee55616d2 100644 --- a/tests/metagpt/tools/test_azure_tts.py +++ b/tests/metagpt/tools/test_azure_tts.py @@ -12,9 +12,11 @@ from pathlib import Path import pytest from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.tools.azure_tts import AzureTTS +config = Config.default() + @pytest.mark.asyncio async def test_azure_tts(mocker): diff --git a/tests/metagpt/tools/test_metagpt_text_to_image.py b/tests/metagpt/tools/test_metagpt_text_to_image.py index d3797a460..bd0fcaf8b 100644 --- a/tests/metagpt/tools/test_metagpt_text_to_image.py +++ b/tests/metagpt/tools/test_metagpt_text_to_image.py @@ -10,9 +10,11 @@ from unittest.mock import AsyncMock import pytest -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image +config = Config.default() + @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index 8dc9e9d5e..0f921887f 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -8,10 +8,12 @@ import pytest -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.llm import LLM from metagpt.tools.moderation import Moderation +config = Config.default() + @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py index 3f9169ddd..4856342d1 100644 --- 
a/tests/metagpt/tools/test_openai_text_to_image.py +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -11,7 +11,7 @@ import openai import pytest from pydantic import BaseModel -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.llm import LLM from metagpt.tools.openai_text_to_image import ( OpenAIText2Image, @@ -19,6 +19,8 @@ from metagpt.tools.openai_text_to_image import ( ) from metagpt.utils.s3 import S3 +config = Config.default() + @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_ut_writer.py b/tests/metagpt/tools/test_ut_writer.py index 3cc7e86bb..3ebbe6d9d 100644 --- a/tests/metagpt/tools/test_ut_writer.py +++ b/tests/metagpt/tools/test_ut_writer.py @@ -20,10 +20,12 @@ from openai.types.chat.chat_completion_message_tool_call import ( Function, ) -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.const import API_QUESTIONS_PATH, UT_PY_PATH from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator +config = Config.default() + class TestUTWriter: @pytest.mark.asyncio diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 7a29ea3ee..75bd9f165 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -2,7 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : unittest of repair_llm_raw_output -from metagpt.config2 import config +from metagpt.config2 import Config + +config = Config.default() """ CONFIG.repair_llm_output should be True before retry_parse_json_text imported. 
diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index 168125448..fdbf86825 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -1,7 +1,7 @@ import json from typing import Optional, Union -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.configs.llm_config import LLMType from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import logger @@ -10,6 +10,8 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import Message +config = Config.default() + OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM