mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
remove global config
This commit is contained in:
parent
6e0990f251
commit
2968c181c1
39 changed files with 193 additions and 123 deletions
|
|
@ -6,12 +6,13 @@ Author: garylin2099
|
|||
import re
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
|
||||
config = Config.default()
|
||||
EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py"
|
||||
MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text()
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from typing import Optional, Set, Tuple
|
|||
import aiofiles
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import (
|
||||
AGGREGATION,
|
||||
COMPOSITION,
|
||||
|
|
@ -40,7 +39,7 @@ class RebuildClassView(Action):
|
|||
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
async def run(self, with_messages=None, format=None):
|
||||
"""
|
||||
Implementation of `Action`'s `run` method.
|
||||
|
||||
|
|
@ -48,6 +47,7 @@ class RebuildClassView(Action):
|
|||
with_messages (Optional[Type]): An optional argument specifying messages to react to.
|
||||
format (str): The format for the prompt schema.
|
||||
"""
|
||||
format = format if format else self.config.prompt_schema
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
repo_parser = RepoParser(base_directory=Path(self.i_context))
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from pydantic import BaseModel
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
|
||||
|
|
@ -84,7 +83,7 @@ class RebuildSequenceView(Action):
|
|||
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
async def run(self, with_messages=None, format=None):
|
||||
"""
|
||||
Implementation of `Action`'s `run` method.
|
||||
|
||||
|
|
@ -92,6 +91,7 @@ class RebuildSequenceView(Action):
|
|||
with_messages (Optional[Type]): An optional argument specifying messages to react to.
|
||||
format (str): The format for the prompt schema.
|
||||
"""
|
||||
format = format if format else self.config.prompt_schema
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
if not self.i_context:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from typing import Any, Callable, Coroutine, Optional, Union
|
|||
from pydantic import TypeAdapter, model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine
|
||||
|
|
@ -134,8 +133,8 @@ class CollectLinks(Action):
|
|||
if len(remove) == 0:
|
||||
break
|
||||
|
||||
model_name = config.llm.model
|
||||
prompt = reduce_message_length(gen_msg(), model_name, system_text, config.llm.max_token)
|
||||
model_name = self.config.llm.model
|
||||
prompt = reduce_message_length(gen_msg(), model_name, system_text, self.config.llm.max_token)
|
||||
logger.debug(prompt)
|
||||
queries = await self._aask(prompt, [system_text])
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -26,7 +25,7 @@ class TalkAction(Action):
|
|||
|
||||
@property
|
||||
def language(self):
|
||||
return self.context.kwargs.language or config.language
|
||||
return self.context.kwargs.language or self.config.language
|
||||
|
||||
@property
|
||||
def prompt(self):
|
||||
|
|
|
|||
|
|
@ -97,20 +97,21 @@ class Config(CLIParams, YamlModel):
|
|||
return Config.from_yaml_file(pathname)
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
def default(cls, reload: bool = False):
|
||||
"""Load default config
|
||||
- Priority: env < default_config_paths
|
||||
- Inside default_config_paths, the latter one overwrites the former one
|
||||
"""
|
||||
default_config_paths: List[Path] = [
|
||||
default_config_paths = (
|
||||
METAGPT_ROOT / "config/config2.yaml",
|
||||
CONFIG_ROOT / "config2.yaml",
|
||||
]
|
||||
|
||||
dicts = [dict(os.environ)]
|
||||
dicts += [Config.read_yaml(path) for path in default_config_paths]
|
||||
final = merge_dict(dicts)
|
||||
return Config(**final)
|
||||
)
|
||||
if reload or default_config_paths not in _CONFIG_CACHE:
|
||||
dicts = [dict(os.environ)]
|
||||
dicts += [Config.read_yaml(path) for path in default_config_paths]
|
||||
final = merge_dict(dicts)
|
||||
_CONFIG_CACHE[default_config_paths] = Config(**final)
|
||||
return _CONFIG_CACHE[default_config_paths]
|
||||
|
||||
@classmethod
|
||||
def from_llm_config(cls, llm_config: dict):
|
||||
|
|
@ -160,4 +161,4 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict:
|
|||
return result
|
||||
|
||||
|
||||
config = Config.default()
|
||||
_CONFIG_CACHE = {}
|
||||
|
|
|
|||
|
|
@ -5,12 +5,23 @@
|
|||
@Author : alexanderwu
|
||||
@File : browser_config.py
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class WebBrowserEngineType(Enum):
|
||||
PLAYWRIGHT = "playwright"
|
||||
SELENIUM = "selenium"
|
||||
CUSTOM = "custom"
|
||||
|
||||
@classmethod
|
||||
def __missing__(cls, key):
|
||||
"""Default type conversion"""
|
||||
return cls.CUSTOM
|
||||
|
||||
|
||||
class BrowserConfig(YamlModel):
|
||||
"""Config for Browser"""
|
||||
|
||||
|
|
|
|||
|
|
@ -5,14 +5,23 @@
|
|||
@Author : alexanderwu
|
||||
@File : search_config.py
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
SERPAPI_GOOGLE = "serpapi"
|
||||
SERPER_GOOGLE = "serper"
|
||||
DIRECT_GOOGLE = "google"
|
||||
DUCK_DUCK_GO = "ddg"
|
||||
CUSTOM_ENGINE = "custom"
|
||||
BING = "bing"
|
||||
|
||||
|
||||
class SearchConfig(YamlModel):
|
||||
"""Config for Search"""
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from __future__ import annotations
|
|||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
|
|
@ -61,7 +61,7 @@ class Context(BaseModel):
|
|||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
kwargs: AttrDict = AttrDict()
|
||||
config: Config = Config.default()
|
||||
config: Config = Field(default_factory=Config.default)
|
||||
|
||||
cost_manager: CostManager = CostManager()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Any, Iterable
|
|||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import config as CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.minecraft.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv
|
||||
|
|
@ -82,7 +82,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv):
|
|||
persist_dir=f"{MC_CKPT_DIR}/skill/vectordb",
|
||||
)
|
||||
|
||||
if CONFIG.resume:
|
||||
if Config.default().resume:
|
||||
logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action")
|
||||
self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, TypeVar
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
|
||||
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
|
||||
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
|
||||
|
|
@ -50,11 +50,14 @@ def exp_cache(
|
|||
"""
|
||||
|
||||
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
|
||||
if not config.exp_pool.enabled:
|
||||
return func
|
||||
|
||||
@functools.wraps(func)
|
||||
async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
|
||||
config = Config.default()
|
||||
|
||||
if not config.exp_pool.enabled:
|
||||
rsp = func(*args, **kwargs)
|
||||
return await rsp if asyncio.iscoroutine(rsp) else rsp
|
||||
|
||||
handler = ExpCacheHandler(
|
||||
func=func,
|
||||
args=args,
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import Config, config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.exp_pool.schema import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_SIMILARITY_TOP_K,
|
||||
|
|
@ -29,7 +29,7 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
config: Config = config
|
||||
config: Config = Field(default_factory=Config.default)
|
||||
|
||||
_storage: Any = None
|
||||
_vector_store: Any = None
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from pathlib import Path
|
|||
from typing import Any, Optional, Union
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.ext.stanford_town.utils.const import PROMPTS_DIR
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
|
@ -62,13 +61,13 @@ class STAction(Action):
|
|||
async def _run_gpt35_max_tokens(self, prompt: str, max_tokens: int = 50, retry: int = 3):
|
||||
for idx in range(retry):
|
||||
try:
|
||||
tmp_max_tokens_rsp = getattr(config.llm, "max_token", 1500)
|
||||
setattr(config.llm, "max_token", max_tokens)
|
||||
tmp_max_tokens_rsp = getattr(self.config.llm, "max_token", 1500)
|
||||
setattr(self.config.llm, "max_token", max_tokens)
|
||||
self.llm.use_system_prompt = False # to make it behave like a non-chat completions
|
||||
|
||||
llm_resp = await self._aask(prompt)
|
||||
|
||||
setattr(config.llm, "max_token", tmp_max_tokens_rsp)
|
||||
setattr(self.config.llm, "max_token", tmp_max_tokens_rsp)
|
||||
logger.info(f"Action: {self.cls_name} llm _run_gpt35_max_tokens raw resp: {llm_resp}")
|
||||
if self._func_validate(llm_resp, prompt):
|
||||
return self._func_cleanup(llm_resp, prompt)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from typing import Union
|
|||
|
||||
from openai import OpenAI
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -48,6 +48,7 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
|||
|
||||
|
||||
def get_embedding(text, model: str = "text-embedding-ada-002"):
|
||||
config = Config.default()
|
||||
text = text.replace("\n", " ")
|
||||
if not text:
|
||||
text = "this is blank"
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@
|
|||
@File : text_to_embedding.py
|
||||
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
|
||||
"""
|
||||
import metagpt.config2
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
|
||||
|
||||
|
||||
async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config):
|
||||
async def text_to_embedding(text, model="text-embedding-ada-002", config: Optional[Config] = None):
|
||||
"""Text to embedding
|
||||
|
||||
:param text: The text used for embedding.
|
||||
|
|
@ -19,6 +20,7 @@ async def text_to_embedding(text, model="text-embedding-ada-002", config: Config
|
|||
:param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
|
||||
"""
|
||||
config = config if config else Config.default()
|
||||
openai_api_key = config.get_openai_llm().api_key
|
||||
proxy = config.get_openai_llm().proxy
|
||||
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@
|
|||
@Desc : Text-to-Image skill, which provides text-to-image functionality.
|
||||
"""
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
import metagpt.config2
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.llm import LLM
|
||||
|
|
@ -17,7 +17,7 @@ from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
|
|||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config):
|
||||
async def text_to_image(text, size_type: str = "512x512", config: Optional[Config] = None):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
|
|
@ -25,6 +25,7 @@ async def text_to_image(text, size_type: str = "512x512", config: Config = metag
|
|||
:param config: Config
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
config = config if config else Config.default()
|
||||
image_declaration = "data:image/png;base64,"
|
||||
|
||||
model_url = config.metagpt_tti_url
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@
|
|||
@File : text_to_speech.py
|
||||
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
|
||||
"""
|
||||
import metagpt.config2
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.tools.azure_tts import oas3_azsure_tts
|
||||
|
|
@ -20,7 +21,7 @@ async def text_to_speech(
|
|||
voice="zh-CN-XiaomoNeural",
|
||||
style="affectionate",
|
||||
role="Girl",
|
||||
config: Config = metagpt.config2.config,
|
||||
config: Optional[Config] = None,
|
||||
):
|
||||
"""Text to speech
|
||||
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
|
|
@ -38,7 +39,7 @@ async def text_to_speech(
|
|||
:return: Returns the Base64-encoded .wav/.mp3 file data if successful, otherwise an empty string.
|
||||
|
||||
"""
|
||||
|
||||
config = config if config else Config.default()
|
||||
subscription_key = config.azure_tts_subscription_key
|
||||
region = config.azure_tts_region
|
||||
if subscription_key and region:
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ import json
|
|||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config as _Config
|
||||
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import MetaGPTLLM
|
||||
|
|
@ -32,6 +32,12 @@ class BrainMemory(BaseModel):
|
|||
last_talk: Optional[str] = None
|
||||
cacheable: bool = True
|
||||
llm: Optional[BaseLLM] = Field(default=None, exclude=True)
|
||||
config: Optional[_Config] = None
|
||||
|
||||
@field_validator("config")
|
||||
@classmethod
|
||||
def set_default_config(cls, config):
|
||||
return config if config else _Config.default()
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
@ -54,9 +60,8 @@ class BrainMemory(BaseModel):
|
|||
texts = [m.content for m in self.knowledge]
|
||||
return "\n".join(texts)
|
||||
|
||||
@staticmethod
|
||||
async def loads(redis_key: str) -> "BrainMemory":
|
||||
redis = Redis(config.redis)
|
||||
async def loads(self, redis_key: str) -> "BrainMemory":
|
||||
redis = Redis(self.config.redis)
|
||||
if not redis_key:
|
||||
return BrainMemory()
|
||||
v = await redis.get(key=redis_key)
|
||||
|
|
@ -70,7 +75,7 @@ class BrainMemory(BaseModel):
|
|||
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
|
||||
if not self.is_dirty:
|
||||
return
|
||||
redis = Redis(config.redis)
|
||||
redis = Redis(self.config.redis)
|
||||
if not redis_key:
|
||||
return False
|
||||
v = self.model_dump_json()
|
||||
|
|
@ -140,7 +145,7 @@ class BrainMemory(BaseModel):
|
|||
return text
|
||||
summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
|
||||
if summary:
|
||||
await self.set_history_summary(history_summary=summary, redis_key=config.redis_key)
|
||||
await self.set_history_summary(history_summary=summary, redis_key=self.config.redis_key)
|
||||
return summary
|
||||
raise ValueError(f"text too long:{text_length}")
|
||||
|
||||
|
|
@ -164,7 +169,7 @@ class BrainMemory(BaseModel):
|
|||
msgs.reverse()
|
||||
self.history = msgs
|
||||
self.is_dirty = True
|
||||
await self.dumps(redis_key=config.redis.key)
|
||||
await self.dumps(redis_key=self.config.redis.key)
|
||||
self.is_dirty = False
|
||||
|
||||
return BrainMemory.to_metagpt_history_format(self.history)
|
||||
|
|
@ -181,7 +186,7 @@ class BrainMemory(BaseModel):
|
|||
|
||||
summary = await self.summarize(llm=llm, max_words=500)
|
||||
|
||||
language = config.language
|
||||
language = self.config.language
|
||||
command = f"Translate the above summary into a {language} title of less than {max_words} words."
|
||||
summaries = [summary, command]
|
||||
msg = "\n".join(summaries)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""RAG Embedding Factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
|
|
@ -9,7 +9,7 @@ from llama_index.embeddings.gemini import GeminiEmbedding
|
|||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
|
@ -18,7 +18,7 @@ from metagpt.rag.factories.base import GenericFactory
|
|||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
creators = {
|
||||
EmbeddingType.OPENAI: self._create_openai,
|
||||
EmbeddingType.AZURE: self._create_azure,
|
||||
|
|
@ -29,6 +29,7 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
self.config = config if self.config else Config.default()
|
||||
|
||||
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
|
||||
"""Key is EmbeddingType."""
|
||||
|
|
@ -40,18 +41,18 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
|
||||
Raise TypeError if embedding type not found.
|
||||
"""
|
||||
if config.embedding.api_type:
|
||||
return config.embedding.api_type
|
||||
if self.config.embedding.api_type:
|
||||
return self.config.embedding.api_type
|
||||
|
||||
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
|
||||
return config.llm.api_type
|
||||
if self.config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
|
||||
return self.config.llm.api_type
|
||||
|
||||
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
|
||||
|
||||
def _create_openai(self) -> OpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
api_base=config.embedding.base_url or config.llm.base_url,
|
||||
api_key=self.config.embedding.api_key or self.config.llm.api_key,
|
||||
api_base=self.config.embedding.base_url or self.config.llm.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
|
@ -60,9 +61,9 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
def _create_azure(self) -> AzureOpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
azure_endpoint=config.embedding.base_url or config.llm.base_url,
|
||||
api_version=config.embedding.api_version or config.llm.api_version,
|
||||
api_key=self.config.embedding.api_key or self.config.llm.api_key,
|
||||
azure_endpoint=self.config.embedding.base_url or self.config.llm.base_url,
|
||||
api_version=self.config.embedding.api_version or self.config.llm.api_version,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
|
@ -71,8 +72,8 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
def _create_gemini(self) -> GeminiEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key,
|
||||
api_base=config.embedding.base_url,
|
||||
api_key=self.config.embedding.api_key,
|
||||
api_base=self.config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
|
@ -81,7 +82,7 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
def _create_ollama(self) -> OllamaEmbedding:
|
||||
params = dict(
|
||||
base_url=config.embedding.base_url,
|
||||
base_url=self.config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
|
@ -90,14 +91,15 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
def _try_set_model_and_batch_size(self, params: dict):
|
||||
"""Set the model_name and embed_batch_size only when they are specified."""
|
||||
if config.embedding.model:
|
||||
params["model_name"] = config.embedding.model
|
||||
if self.config.embedding.model:
|
||||
params["model_name"] = self.config.embedding.model
|
||||
|
||||
if config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = config.embedding.embed_batch_size
|
||||
if self.config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = self.config.embedding.embed_batch_size
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
def get_rag_embedding(key: EmbeddingType = None, config: Optional[Config] = None):
|
||||
return RAGEmbeddingFactory(config=config).get_rag_embedding(key)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from llama_index.core.llms import (
|
|||
LLMMetadata,
|
||||
)
|
||||
from llama_index.core.llms.callbacks import llm_completion_callback
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
|
|
@ -26,9 +26,23 @@ class RAGLLM(CustomLLM):
|
|||
"""
|
||||
|
||||
model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.")
|
||||
context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
|
||||
num_output: int = config.llm.max_token
|
||||
model_name: str = config.llm.model
|
||||
context_window: int = -1
|
||||
num_output: int = -1
|
||||
model_name: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def update_from_config(self):
|
||||
config = Config.default()
|
||||
if self.context_window < 0:
|
||||
self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
|
||||
|
||||
if self.num_output < 0:
|
||||
self.num_output = config.llm.max_token
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = config.llm.model
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from llama_index.core.schema import TextNode
|
|||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
|
@ -45,6 +45,7 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
|
|||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
config = Config.default()
|
||||
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
|
||||
config.embedding.api_type, 1536
|
||||
)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def generate_repo(
|
|||
recover_path=None,
|
||||
):
|
||||
"""Run the startup logic. Can be called from CLI or other Python scripts."""
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.context import Context
|
||||
from metagpt.roles import (
|
||||
Architect,
|
||||
|
|
@ -38,6 +38,8 @@ def generate_repo(
|
|||
)
|
||||
from metagpt.team import Team
|
||||
|
||||
config = Config.default()
|
||||
|
||||
config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
|
||||
ctx = Context(config=config)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,33 +6,18 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from metagpt.tools import libs # this registers all tools
|
||||
from metagpt.tools.tool_registry import TOOL_REGISTRY
|
||||
from metagpt.configs.search_config import SearchEngineType
|
||||
from metagpt.configs.browser_config import WebBrowserEngineType
|
||||
|
||||
|
||||
_ = libs, TOOL_REGISTRY # Avoid pre-commit error
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
SERPAPI_GOOGLE = "serpapi"
|
||||
SERPER_GOOGLE = "serper"
|
||||
DIRECT_GOOGLE = "google"
|
||||
DUCK_DUCK_GO = "ddg"
|
||||
CUSTOM_ENGINE = "custom"
|
||||
BING = "bing"
|
||||
|
||||
|
||||
class WebBrowserEngineType(Enum):
|
||||
PLAYWRIGHT = "playwright"
|
||||
SELENIUM = "selenium"
|
||||
CUSTOM = "custom"
|
||||
|
||||
@classmethod
|
||||
def __missing__(cls, key):
|
||||
"""Default type conversion"""
|
||||
return cls.CUSTOM
|
||||
|
||||
|
||||
class SearchInterface:
|
||||
async def asearch(self, *args, **kwargs):
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["SearchEngineType", "WebBrowserEngineType", "TOOL_REGISTRY"]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@
|
|||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
|
|
@ -36,11 +38,11 @@ class GPTvGenerator:
|
|||
It utilizes a vision model to analyze the layout from an image and generate webpage codes accordingly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: Optional[Config]):
|
||||
"""Initialize GPTvGenerator class with default values from the configuration."""
|
||||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
|
||||
config = config if config else Config.default()
|
||||
self.llm = LLM(llm_config=config.get_openai_llm())
|
||||
self.llm.model = "gpt-4-vision-preview"
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
|
||||
from metagpt.utils.common import awrite
|
||||
|
||||
|
|
@ -282,6 +282,7 @@ class UTGenerator:
|
|||
"""Choose based on different calling methods"""
|
||||
result = ""
|
||||
if self.chatgpt_method == "API":
|
||||
config = Config.default()
|
||||
result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@
|
|||
"""
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
|
||||
|
||||
def get_embedding() -> OpenAIEmbedding:
|
||||
config = Config.default()
|
||||
llm = config.get_openai_llm()
|
||||
if llm is None:
|
||||
raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.")
|
||||
|
|
|
|||
|
|
@ -13,10 +13,11 @@ from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion impo
|
|||
OpenAIChatCompletion,
|
||||
)
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
|
||||
|
||||
def make_sk_kernel():
|
||||
config = Config.default()
|
||||
kernel = sk.Kernel()
|
||||
if llm := config.get_azure_llm():
|
||||
kernel.add_chat_service(
|
||||
|
|
|
|||
|
|
@ -9,12 +9,14 @@ import asyncio
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import awrite, check_cmd_exists
|
||||
|
||||
|
||||
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
async def mermaid_to_file(
|
||||
engine, mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None
|
||||
) -> int:
|
||||
"""suffix: png/svg/pdf
|
||||
|
||||
:param mermaid_code: mermaid code
|
||||
|
|
@ -24,6 +26,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
:return: 0 if succeed, -1 if failed
|
||||
"""
|
||||
# Write the Mermaid code to a temporary file
|
||||
config = config if config else Config.default()
|
||||
dir_name = os.path.dirname(output_file_without_suffix)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ from urllib.parse import urljoin
|
|||
|
||||
from pyppeteer import launch
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None) -> int:
|
||||
"""
|
||||
Converts the given Mermaid code to various output formats and saves them to files.
|
||||
|
||||
|
|
@ -27,6 +27,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
Returns:
|
||||
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
|
||||
"""
|
||||
config = config if config else Config.default()
|
||||
suffixes = ["png", "svg", "pdf"]
|
||||
__dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import regex as re
|
||||
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
|
||||
|
|
@ -154,7 +154,9 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType =
|
|||
return output
|
||||
|
||||
|
||||
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
|
||||
def repair_llm_raw_output(
|
||||
output: str, req_keys: list[str], repair_type: RepairType = None, config: Optional[Config] = None
|
||||
) -> str:
|
||||
"""
|
||||
in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
|
||||
so here we try to repair it and use all repair methods by default.
|
||||
|
|
@ -169,6 +171,7 @@ def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairT
|
|||
target: { xxx }
|
||||
output: { xxx }]
|
||||
"""
|
||||
config = config if config else Config.default()
|
||||
if not config.repair_llm_output:
|
||||
return output
|
||||
|
||||
|
|
@ -256,6 +259,7 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
"next_action":"None"
|
||||
}
|
||||
"""
|
||||
config = Config.default()
|
||||
if retry_state.outcome.failed:
|
||||
if retry_state.args:
|
||||
# # can't be used as args=retry_state.args
|
||||
|
|
@ -276,8 +280,12 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
return run_and_passon
|
||||
|
||||
|
||||
def repair_stop_after_attempt(retry_state):
|
||||
return stop_after_attempt(3 if Config.default().repair_llm_output else 0)(retry_state)
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3 if config.repair_llm_output else 0),
|
||||
stop=repair_stop_after_attempt,
|
||||
wait=wait_fixed(1),
|
||||
after=run_after_exp_and_passon_next_retry(logger),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@ import asyncio
|
|||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.swe_agent import SWEAgent
|
||||
from metagpt.tools.libs.terminal import Terminal
|
||||
from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset
|
||||
|
||||
config = Config.default()
|
||||
# Specify by yourself
|
||||
TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo"
|
||||
DATA_DIR = METAGPT_ROOT / "data/hugging_face"
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_document.py
|
||||
"""
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.document import Repo
|
||||
from metagpt.logs import logger
|
||||
|
||||
config = Config.default()
|
||||
|
||||
|
||||
def set_existing_repo(path):
|
||||
repo1 = Repo.from_path(path)
|
||||
|
|
|
|||
|
|
@ -12,9 +12,11 @@ from pathlib import Path
|
|||
import pytest
|
||||
from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.tools.azure_tts import AzureTTS
|
||||
|
||||
config = Config.default()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_tts(mocker):
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
|
||||
|
||||
config = Config.default()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draw(mocker):
|
||||
|
|
|
|||
|
|
@ -8,10 +8,12 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.tools.moderation import Moderation
|
||||
|
||||
config = Config.default()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import openai
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.tools.openai_text_to_image import (
|
||||
OpenAIText2Image,
|
||||
|
|
@ -19,6 +19,8 @@ from metagpt.tools.openai_text_to_image import (
|
|||
)
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
config = Config.default()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draw(mocker):
|
||||
|
|
|
|||
|
|
@ -20,10 +20,12 @@ from openai.types.chat.chat_completion_message_tool_call import (
|
|||
Function,
|
||||
)
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import API_QUESTIONS_PATH, UT_PY_PATH
|
||||
from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator
|
||||
|
||||
config = Config.default()
|
||||
|
||||
|
||||
class TestUTWriter:
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of repair_llm_raw_output
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
|
||||
config = Config.default()
|
||||
|
||||
"""
|
||||
CONFIG.repair_llm_output should be True before retry_parse_json_text imported.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -10,6 +10,8 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
|||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
config = Config.default()
|
||||
|
||||
OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue