use config

This commit is contained in:
geekan 2024-01-10 20:19:56 +08:00
parent 21cac0bffb
commit 479bbc9b2d
37 changed files with 102 additions and 276 deletions

View file

@ -12,7 +12,7 @@ from pathlib import Path
import aiofiles
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import (
AGGREGATION,
COMPOSITION,
@ -20,6 +20,7 @@ from metagpt.const import (
GENERALIZATION,
GRAPH_REPO_FILE_REPO,
)
from metagpt.context import CONTEXT
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
@ -29,8 +30,8 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = CONTEXT.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONTEXT.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.i_context))
# use pylint
@ -48,9 +49,9 @@ class RebuildClassView(Action):
await graph_db.save()
async def _create_mermaid_class_views(self, graph_db):
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path = Path(CONTEXT.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / CONFIG.git_repo.workdir.name
pathname = path / CONTEXT.git_repo.workdir.name
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)

View file

@ -12,7 +12,6 @@ from pathlib import Path
from typing import List
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
@ -21,8 +20,8 @@ from metagpt.utils.graph_repository import GraphKeyword
class RebuildSequenceView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = CONTEXT.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONTEXT.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
for entry in entries:

View file

@ -9,6 +9,7 @@ from pydantic import Field, parse_obj_as
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
@ -127,8 +128,8 @@ class CollectLinks(Action):
if len(remove) == 0:
break
model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
model_name = config.get_openai_llm().model
prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
@ -182,8 +183,6 @@ class WebBrowseAndSummarize(Action):
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
@ -246,8 +245,6 @@ class ConductResearch(Action):
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_report:
self.llm.model = CONFIG.model_for_researcher_report
async def run(
self,

View file

@ -8,7 +8,7 @@
from typing import Optional
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
from metagpt.logs import logger
@ -76,7 +76,7 @@ class WriteTeachingPlanPart(Action):
return value
# FIXME: 从Context中获取参数而非从options
merged_opts = CONFIG.options or {}
merged_opts = CONTEXT.options or {}
try:
return value.format(**merged_opts)
except KeyError as e:

View file

@ -71,6 +71,9 @@ class Config(CLIParams, YamlModel):
METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = ""
language: str = "English"
redis_key: str = "placeholder"
mmdc: str = "mmdc"
puppeteer_config: str = ""
pyppeteer_executable_path: str = ""
@classmethod
def default(cls):

View file

@ -13,7 +13,7 @@ import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
class Example(BaseModel):
@ -80,7 +80,7 @@ class SkillsDeclaration(BaseModel):
return {}
# List of skills that the agent chooses to activate.
agent_skills = CONFIG.agent_skills
agent_skills = CONTEXT.kwargs.agent_skills
if not agent_skills:
return {}

View file

@ -7,7 +7,6 @@
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
"""
from metagpt.config import CONFIG
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
@ -19,6 +18,4 @@ async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
if CONFIG.OPENAI_API_KEY or openai_api_key:
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
raise EnvironmentError
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)

View file

@ -8,6 +8,7 @@
"""
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import BASE64_FORMAT
from metagpt.tools.azure_tts import oas3_azsure_tts
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
@ -47,7 +48,7 @@ async def text_to_speech(
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
audio_declaration = "data:audio/wav;base64,"
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
s3 = S3()
s3 = S3(config.s3)
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT)
if url:
return f"[{text}]({url})"

View file

@ -13,7 +13,7 @@ import aiohttp
import requests
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
@ -47,7 +47,8 @@ class OpenAIText2Embedding:
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY
self.openai_llm = config.get_openai_llm()
self.openai_api_key = openai_api_key or self.openai_llm.api_key
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
"""Text to embedding
@ -57,7 +58,7 @@ class OpenAIText2Embedding:
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {}
proxies = {"proxy": self.openai_llm.proxy} if self.openai_llm.proxy else {}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
data = {"input": text, "model": model}
url = "https://api.openai.com/v1/embeddings"
@ -83,5 +84,5 @@ async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", op
if not text:
return ""
if not openai_api_key:
openai_api_key = CONFIG.OPENAI_API_KEY
openai_api_key = config.get_openai_llm().api_key
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)

View file

@ -1,133 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@deepwisdom.ai)
# @Desc :
import asyncio
import base64
import io
import json
from os.path import join
from typing import List
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
from metagpt.config import CONFIG
from metagpt.const import SD_OUTPUT_FILE_REPO
from metagpt.logs import logger
# Default request payload template for the Stable Diffusion web API.
# SDEngine copies and fills this in before each text-to-image call.
payload = {
    # Prompt text — filled in per request by SDEngine.construct_payload.
    "prompt": "",
    "negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
    "override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
    # -1 asks the server to pick a random seed.
    "seed": -1,
    # Batch / sampling controls.
    "batch_size": 1,
    "n_iter": 1,
    "steps": 20,
    "cfg_scale": 7,
    # Output resolution (portrait by default).
    "width": 512,
    "height": 768,
    "restore_faces": False,
    "tiling": False,
    "do_not_save_samples": False,
    "do_not_save_grid": False,
    # High-res fix settings (disabled by default).
    "enable_hr": False,
    "hr_scale": 2,
    "hr_upscaler": "Latent",
    "hr_second_pass_steps": 0,
    "hr_resize_x": 0,
    "hr_resize_y": 0,
    "hr_upscale_to_x": 0,
    "hr_upscale_to_y": 0,
    "truncate_x": 0,
    "truncate_y": 0,
    "applied_old_hires_behavior_to": None,
    "eta": None,
    "sampler_index": "DPM++ SDE Karras",
    "alwayson_scripts": {},
}

# Fallback negative prompt used when the caller does not supply one.
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
class SDEngine:
    """Thin async client for a Stable Diffusion web API.

    Reads the service location from CONFIG ("SD_URL" / "SD_T2I_API"), keeps a
    private copy of the module-level `payload` template, and posts
    text-to-image jobs, saving the returned base64 images to disk.
    """

    def __init__(self):
        # Service endpoints come from the project configuration.
        self.sd_url = CONFIG.get("SD_URL")
        self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
        # Work on a private deep copy of the module-level defaults so that
        # separate engines (or repeated construct_payload calls) do not
        # mutate the shared `payload` template — including its nested
        # "override_settings" dict. The original aliased the template.
        import copy

        self.payload = copy.deepcopy(payload)
        logger.info(self.sd_t2i_url)

    def construct_payload(
        self,
        prompt,
        negtive_prompt=default_negative_prompt,
        width=512,
        height=512,
        sd_model="galaxytimemachinesGTM_photoV20",
    ):
        """Fill the request payload with the given settings and return it.

        The parameter name `negtive_prompt` keeps its original (misspelled)
        form for backward compatibility with keyword callers.
        """
        self.payload["prompt"] = prompt
        # BUG FIX: this used to write the misspelled key "negtive_prompt",
        # leaving the template's real "negative_prompt" entry untouched, so
        # the caller's negative prompt never reached the API.
        self.payload["negative_prompt"] = negtive_prompt
        self.payload["width"] = width
        self.payload["height"] = height
        self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
        logger.info(f"call sd payload is {self.payload}")
        return self.payload

    def _save(self, imgs, save_name=""):
        # Persist decoded images under the configured SD output directory.
        save_dir = CONFIG.path / SD_OUTPUT_FILE_REPO
        save_dir.mkdir(parents=True, exist_ok=True)
        batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)

    async def run_t2i(self, prompts: List):
        """Run one text-to-image request per payload in `prompts`.

        Each element of `prompts` is posted as a full request payload; the
        results of request i are saved under the name ``output_<i>``.
        """
        # The context manager guarantees the HTTP session is closed even if a
        # request raises (the original leaked the session on error).
        async with ClientSession() as session:
            for payload_idx, request_payload in enumerate(prompts):
                results = await self.run(url=self.sd_t2i_url, payload=request_payload, session=session)
                self._save(results, save_name=f"output_{payload_idx}")

    async def run(self, url, payload, session):
        """POST `payload` to `url` and return the list of base64-encoded images."""
        async with session.post(url, json=payload, timeout=600) as rsp:
            data = await rsp.read()

        rsp_json = json.loads(data)
        imgs = rsp_json["images"]
        logger.info(f"callback rsp json is {rsp_json.keys()}")
        return imgs

    async def run_i2i(self):
        # TODO: add the image-to-image API call.
        raise NotImplementedError

    async def run_sam(self):
        # TODO: add the SAM (segment-anything) API call.
        raise NotImplementedError
def decode_base64_to_image(img, save_name):
    """Decode one base64-encoded image string and save it as ``<save_name>.png``.

    Returns the (PngInfo, Image) pair for the saved image.
    """
    # NOTE(review): taking index [0] after splitting on "," assumes the base64
    # data precedes any comma; for a "data:image/png;base64,<data>" URL the
    # data would sit at index [1] — confirm against the SD API response format.
    encoded = img.split(",", 1)[0]
    decoded_image = Image.open(io.BytesIO(base64.b64decode(encoded)))
    png_metadata = PngImagePlugin.PngInfo()
    logger.info(save_name)
    decoded_image.save(f"{save_name}.png", pnginfo=png_metadata)
    return png_metadata, decoded_image
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
    """Decode and save a batch of base64-encoded images.

    BUG FIX: the original re-assigned ``save_name = join(save_dir, save_name)``
    inside the loop, so from the second image onward the directory was
    prepended repeatedly, producing corrupt paths. The path is now built once.

    NOTE(review): all images in the batch still share a single target path, so
    images after the first overwrite it — this preserves the original
    single-image behavior; confirm whether per-index names are wanted.
    """
    target = join(save_dir, save_name)
    for img in imgs:
        decode_base64_to_image(img, save_name=target)
if __name__ == "__main__":
    engine = SDEngine()
    prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
    # BUG FIX: run_t2i expects a list of request payloads (it iterates and
    # POSTs each element). The original passed the raw prompt string, which
    # would have been iterated character by character and also discarded the
    # payload built by construct_payload.
    t2i_payload = engine.construct_payload(prompt)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(engine.run_t2i([t2i_payload]))

View file

@ -7,6 +7,8 @@ import json
from concurrent import futures
from typing import Literal, overload
from metagpt.config2 import config
try:
from duckduckgo_search import DDGS
except ImportError:
@ -15,8 +17,6 @@ except ImportError:
"You can install it by running the command: `pip install -e.[search-ddg]`"
)
from metagpt.config import CONFIG
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
@ -31,8 +31,8 @@ class DDGAPIWrapper:
executor: futures.Executor | None = None,
):
kwargs = {}
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
if config.proxy:
kwargs["proxies"] = config.proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)

View file

@ -11,7 +11,7 @@ from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
try:
@ -35,7 +35,7 @@ class GoogleAPIWrapper(BaseModel):
@field_validator("google_api_key", mode="before")
@classmethod
def check_google_api_key(cls, val: str):
val = val or CONFIG.google_api_key
val = val or config.search["google"].api_key
if not val:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
@ -47,7 +47,7 @@ class GoogleAPIWrapper(BaseModel):
@field_validator("google_cse_id", mode="before")
@classmethod
def check_google_cse_id(cls, val: str):
val = val or CONFIG.google_cse_id
val = val or config.search["google"].cse_id
if not val:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
@ -59,8 +59,8 @@ class GoogleAPIWrapper(BaseModel):
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
if config.proxy:
parse_result = urlparse(config.proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"

View file

@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.config2 import config
class SerpAPIWrapper(BaseModel):
@ -32,7 +32,7 @@ class SerpAPIWrapper(BaseModel):
@field_validator("serpapi_api_key", mode="before")
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or CONFIG.serpapi_api_key
val = val or config.search["serpapi"].api_key
if not val:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "

View file

@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.config2 import config
class SerperWrapper(BaseModel):
@ -25,7 +25,7 @@ class SerperWrapper(BaseModel):
@field_validator("serper_api_key", mode="before")
@classmethod
def check_serper_api_key(cls, val: str):
val = val or CONFIG.serper_api_key
val = val or config.search["serper"].api_key
if not val:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "

View file

@ -8,7 +8,6 @@ from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, overload
from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
from metagpt.utils.parse_html import WebPage
@ -19,7 +18,6 @@ class WebBrowserEngine:
engine: WebBrowserEngineType | None = None,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
engine = engine or CONFIG.web_browser_engine
if engine is None:
raise NotImplementedError

View file

@ -12,7 +12,7 @@ from typing import Literal
from playwright.async_api import async_playwright
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.parse_html import WebPage
@ -33,13 +33,13 @@ class PlaywrightWrapper:
**kwargs,
) -> None:
if browser_type is None:
browser_type = CONFIG.playwright_browser_type
browser_type = config.browser["playwright"].driver
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if CONFIG.global_proxy and "proxy" not in launch_kwargs:
if config.proxy and "proxy" not in launch_kwargs:
args = launch_kwargs.get("args", [])
if not any(str.startswith(i, "--proxy-server=") for i in args):
launch_kwargs["proxy"] = {"server": CONFIG.global_proxy}
launch_kwargs["proxy"] = {"server": config.proxy}
self.launch_kwargs = launch_kwargs
context_kwargs = {}
if "ignore_https_errors" in kwargs:
@ -79,8 +79,8 @@ class PlaywrightWrapper:
executable_path = Path(browser_type.executable_path)
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
kwargs = {}
if CONFIG.global_proxy:
kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy}
if config.proxy:
kwargs["env"] = {"ALL_PROXY": config.proxy}
await _install_browsers(self.browser_type, **kwargs)
if self._has_run_precheck:

View file

@ -17,7 +17,7 @@ from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.core.download_manager import WDMDownloadManager
from webdriver_manager.core.http import WDMHttpClient
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.utils.parse_html import WebPage
@ -41,12 +41,10 @@ class SeleniumWrapper:
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
) -> None:
if browser_type is None:
browser_type = CONFIG.selenium_browser_type
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if CONFIG.global_proxy and "proxy-server" not in launch_kwargs:
launch_kwargs["proxy-server"] = CONFIG.global_proxy
if config.proxy and "proxy-server" not in launch_kwargs:
launch_kwargs["proxy-server"] = config.proxy
self.executable_path = launch_kwargs.pop("executable_path", None)
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
@ -97,8 +95,8 @@ _webdriver_manager_types = {
class WDMHttpProxyClient(WDMHttpClient):
def get(self, url, **kwargs):
if "proxies" not in kwargs and CONFIG.global_proxy:
kwargs["proxies"] = {"all_proxy": CONFIG.global_proxy}
if "proxies" not in kwargs and config.proxy:
kwargs["proxies"] = {"all_proxy": config.proxy}
return super().get(url, **kwargs)

View file

@ -12,7 +12,7 @@ from pathlib import Path
import aiofiles
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
@ -35,9 +35,9 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
await f.write(mermaid_code)
# tmp.write_text(mermaid_code, encoding="utf-8")
engine = CONFIG.mermaid_engine.lower()
engine = config.mermaid["default"].engine
if engine == "nodejs":
if check_cmd_exists(CONFIG.mmdc) != 0:
if check_cmd_exists(config.mmdc) != 0:
logger.warning(
"RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc,"
"or consider changing MERMAID_ENGINE to `playwright`, `pyppeteer`, or `ink`."
@ -49,11 +49,11 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
# Call the `mmdc` command to convert the Mermaid code to a PNG
logger.info(f"Generating {output_file}..")
if CONFIG.puppeteer_config:
if config.puppeteer_config:
commands = [
CONFIG.mmdc,
config.mmdc,
"-p",
CONFIG.puppeteer_config,
config.puppeteer_config,
"-i",
str(tmp),
"-o",
@ -64,7 +64,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
str(height),
]
else:
commands = [CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]
commands = [config.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]
process = await asyncio.create_subprocess_shell(
" ".join(commands), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)

View file

@ -10,7 +10,7 @@ from urllib.parse import urljoin
from pyppeteer import launch
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
@ -30,10 +30,10 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
if CONFIG.pyppeteer_executable_path:
if config.pyppeteer_executable_path:
browser = await launch(
headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
executablePath=config.pyppeteer_executable_path,
args=["--disable-extensions", "--no-sandbox"],
)
else:

View file

@ -9,7 +9,7 @@ from typing import Callable, Union
import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder
@ -152,7 +152,7 @@ def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairT
target: { xxx }
output: { xxx }]
"""
if not CONFIG.repair_llm_output:
if not config.repair_llm_output:
return output
# do the repairation usually for non-openai models
@ -231,7 +231,7 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
fix_str = "try to fix it, " if config.repair_llm_output else ""
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
@ -244,7 +244,7 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
@retry(
stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
stop=stop_after_attempt(3 if config.repair_llm_output else 0),
wait=wait_fixed(1),
after=run_after_exp_and_passon_next_retry(logger),
)

View file

@ -11,7 +11,6 @@ from pathlib import Path
import pytest
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
@ -22,7 +21,7 @@ async def test_rebuild():
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
graph_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
assert graph_file_repo.changed_files

View file

@ -10,7 +10,6 @@ from pathlib import Path
import pytest
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
from metagpt.utils.common import aread
@ -22,20 +21,20 @@ from metagpt.utils.git_repository import ChangeType
async def test_rebuild():
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
graph_db_filename = Path(CONFIG.git_repo.workdir.name).with_suffix(".json")
graph_db_filename = Path(CONTEXT.git_repo.workdir.name).with_suffix(".json")
await FileRepository.save_file(
filename=str(graph_db_filename),
relative_path=GRAPH_REPO_FILE_REPO,
content=data,
)
CONFIG.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
CONFIG.git_repo.commit("commit1")
CONTEXT.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
CONTEXT.git_repo.commit("commit1")
action = RebuildSequenceView(
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
graph_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
assert graph_file_repo.changed_files

View file

@ -9,7 +9,6 @@
import pytest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.context import CONTEXT
from metagpt.logs import logger
@ -181,12 +180,12 @@ async def test_summarize_code():
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "src"
await CONTEXT.file_repo.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
await CONTEXT.file_repo.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
await CONTEXT.file_repo.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY)
await CONTEXT.file_repo.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY)
await CONTEXT.file_repo.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY)
await CONTEXT.file_repo.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY)
await CONTEXT.file_repo.save_file(filename="food.py", relative_path=CONTEXT.src_workspace, content=FOOD_PY)
await CONTEXT.file_repo.save_file(filename="game.py", relative_path=CONTEXT.src_workspace, content=GAME_PY)
await CONTEXT.file_repo.save_file(filename="main.py", relative_path=CONTEXT.src_workspace, content=MAIN_PY)
await CONTEXT.file_repo.save_file(filename="snake.py", relative_path=CONTEXT.src_workspace, content=SNAKE_PY)
src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONTEXT.src_workspace)
all_files = src_file_repo.all_files
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
action = SummarizeCode(context=ctx)

View file

@ -10,13 +10,13 @@ from pathlib import Path
import pytest
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
from metagpt.learn.skill_loader import SkillsDeclaration
@pytest.mark.asyncio
async def test_suite():
CONFIG.agent_skills = [
CONTEXT.kwargs.agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},

View file

@ -9,14 +9,14 @@
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.learn.text_to_embedding import text_to_embedding
@pytest.mark.asyncio
async def test_text_to_embedding():
# Prerequisites
assert CONFIG.OPENAI_API_KEY
assert config.get_openai_llm()
v = await text_to_embedding(text="Panda emoji")
assert len(v.data) > 0

View file

@ -12,6 +12,7 @@ import pytest
from azure.cognitiveservices.speech import ResultReason
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.tools.azure_tts import AzureTTS
@ -32,7 +33,7 @@ async def test_azure_tts():
Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.”
</mstts:express-as>
"""
path = CONFIG.path / "tts"
path = config.workspace.path / "tts"
path.mkdir(exist_ok=True, parents=True)
filename = path / "girl.wav"
filename.unlink(missing_ok=True)

View file

@ -12,14 +12,14 @@ from pathlib import Path
import pytest
import requests
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
@pytest.mark.asyncio
async def test_oas2_svc():
workdir = Path(__file__).parent.parent.parent.parent
script_pathname = workdir / "metagpt/tools/metagpt_oas3_api_svc.py"
env = CONFIG.new_environ()
env = CONTEXT.new_environ()
env["PYTHONPATH"] = str(workdir) + ":" + env.get("PYTHONPATH", "")
process = subprocess.Popen(["python", str(script_pathname)], cwd=str(workdir), env=env)
await asyncio.sleep(5)

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
@ -24,7 +24,7 @@ async def test_draw(mocker):
mock_post.return_value.__aenter__.return_value = mock_response
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
binary_data = await oas3_metagpt_text_to_image("Panda emoji")
assert binary_data

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.tools.moderation import Moderation
@ -24,9 +24,7 @@ from metagpt.tools.moderation import Moderation
)
async def test_amoderation(content):
# Prerequisites
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
assert not CONFIG.OPENAI_API_TYPE
assert CONFIG.OPENAI_API_MODEL
assert config.get_openai_llm()
moderation = Moderation(LLM())
results = await moderation.amoderation(content=content)

View file

@ -8,16 +8,14 @@
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
@pytest.mark.asyncio
async def test_embedding():
# Prerequisites
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
assert not CONFIG.OPENAI_API_TYPE
assert CONFIG.OPENAI_API_MODEL
assert config.get_openai_llm()
result = await oas3_openai_text_to_embedding("Panda emoji")
assert result

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.tools.openai_text_to_image import (
OpenAIText2Image,
oas3_openai_text_to_image,
@ -18,9 +18,7 @@ from metagpt.tools.openai_text_to_image import (
@pytest.mark.asyncio
async def test_draw():
# Prerequisites
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
assert not CONFIG.OPENAI_API_TYPE
assert CONFIG.OPENAI_API_MODEL
assert config.get_openai_llm()
binary_data = await oas3_openai_text_to_image("Panda emoji")
assert binary_data

View file

@ -12,14 +12,14 @@ from pathlib import Path
import pytest
import requests
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
@pytest.mark.asyncio
async def test_hello():
workdir = Path(__file__).parent.parent.parent.parent
script_pathname = workdir / "metagpt/tools/openapi_v3_hello.py"
env = CONFIG.new_environ()
env = CONTEXT.new_environ()
env["PYTHONPATH"] = str(workdir) + ":" + env.get("PYTHONPATH", "")
process = subprocess.Popen(["python", str(script_pathname)], cwd=workdir, env=env)
await asyncio.sleep(5)

View file

@ -1,26 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/22 02:40
# @Author : stellahong (stellahong@deepwisdom.ai)
#
import os
from metagpt.config import CONFIG
from metagpt.tools.sd_engine import SDEngine
def test_sd_engine_init():
    """A fresh engine must start from the default payload (random-seed sentinel)."""
    engine = SDEngine()
    assert engine.payload["seed"] == -1
def test_sd_engine_generate_prompt():
    """construct_payload must write the prompt text into the request payload."""
    engine = SDEngine()
    engine.construct_payload(prompt="test")
    assert engine.payload["prompt"] == "test"
async def test_sd_engine_run_t2i():
    """End-to-end: one t2i request should produce output_0.png on disk."""
    # NOTE(review): this coroutine carries no pytest-asyncio marker in this
    # file, so pytest may skip it or leave it unawaited — confirm the
    # project's asyncio_mode configuration.
    engine = SDEngine()
    await engine.run_t2i(prompts=["test"])
    expected = CONFIG.path / "resources" / "SD_Output" / "output_0.png"
    assert os.path.exists(expected)

View file

@ -14,7 +14,7 @@ from typing import Callable
import pytest
import tests.data.search
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -50,13 +50,12 @@ async def test_search_engine(search_engine_type, run_func: Callable, max_results
# Prerequisites
cache_json_path = None
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert CONFIG.SERPAPI_API_KEY and CONFIG.SERPAPI_API_KEY != "YOUR_API_KEY"
assert config.search["serpapi"]
cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert CONFIG.GOOGLE_API_KEY and CONFIG.GOOGLE_API_KEY != "YOUR_API_KEY"
assert CONFIG.GOOGLE_CSE_ID and CONFIG.GOOGLE_CSE_ID != "YOUR_CSE_ID"
assert config.search["google"]
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert CONFIG.SERPER_API_KEY and CONFIG.SERPER_API_KEY != "YOUR_API_KEY"
assert config.search["serper"]
cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"
if cache_json_path:

View file

@ -9,7 +9,7 @@ from pathlib import Path
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import API_QUESTIONS_PATH, UT_PY_PATH
from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator
@ -20,9 +20,7 @@ class TestUTWriter:
# Prerequisites
swagger_file = Path(__file__).parent / "../../data/ut_writer/yft_swaggerApi.json"
assert swagger_file.exists()
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_API_KEY"
assert not CONFIG.OPENAI_API_TYPE
assert CONFIG.OPENAI_API_MODEL
assert config.get_openai_llm()
tags = ["测试", "作业"]
# 这里在文件中手动加入了两个测试标签的API

View file

@ -9,6 +9,7 @@
import pytest
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
from metagpt.utils.common import check_cmd_exists
from metagpt.utils.mermaid import MMC1, mermaid_to_file
@ -22,7 +23,7 @@ async def test_mermaid(engine):
assert check_cmd_exists("npm") == 0
CONFIG.mermaid_engine = engine
save_to = CONFIG.git_repo.workdir / f"{CONFIG.mermaid_engine}/1"
save_to = CONTEXT.git_repo.workdir / f"{CONFIG.mermaid_engine}/1"
await mermaid_to_file(MMC1, save_to)
# ink does not support pdf

View file

@ -2,13 +2,13 @@
# -*- coding: utf-8 -*-
# @Desc : unittest of repair_llm_raw_output
from metagpt.config import CONFIG
from metagpt.config2 import config
"""
CONFIG.repair_llm_output should be True before retry_parse_json_text imported.
so we move `from ... impot ...` into each `test_xx` to avoid `Module level import not at top of file` format warning.
"""
CONFIG.repair_llm_output = True
config.repair_llm_output = True
def test_repair_case_sensitivity():