mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
feat: merge geekan:dev
This commit is contained in:
commit
cbad5170ac
25 changed files with 83 additions and 630 deletions
|
|
@ -184,7 +184,7 @@ class WebBrowseAndSummarize(Action):
|
|||
super().__init__(**kwargs)
|
||||
|
||||
self.web_browser_engine = WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
|
||||
run_func=self.browse_func,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,270 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Provide configuration, singleton
|
||||
@Modified By: mashenquan, 2023/11/27.
|
||||
1. According to Section 2.2.3.11 of RFC 135, add git repository support.
|
||||
2. Add the parameter `src_workspace` for the old version project path.
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.common import require_python_version
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
class NotConfiguredException(Exception):
|
||||
"""Exception raised for errors in the configuration.
|
||||
|
||||
Attributes:
|
||||
message -- explanation of the error
|
||||
"""
|
||||
|
||||
def __init__(self, message="The required configuration is not set"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""
|
||||
Regular usage method:
|
||||
config = Config("config.yaml")
|
||||
secret_key = config.get_key("MY_SECRET_KEY")
|
||||
print("Secret key:", secret_key)
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
home_yaml_file = Path.home() / ".metagpt/config.yaml"
|
||||
key_yaml_file = METAGPT_ROOT / "config/key.yaml"
|
||||
default_yaml_file = METAGPT_ROOT / "config/config.yaml"
|
||||
|
||||
def __init__(self, yaml_file=default_yaml_file, cost_data=""):
|
||||
global_options = OPTIONS.get()
|
||||
# cli paras
|
||||
self.project_path = ""
|
||||
self.project_name = ""
|
||||
self.inc = False
|
||||
self.reqa_file = ""
|
||||
self.max_auto_summarize_code = 0
|
||||
self.git_reinit = False
|
||||
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
|
||||
self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager()
|
||||
self._update()
|
||||
global_options.update(OPTIONS.get())
|
||||
logger.debug("Config loading done.")
|
||||
|
||||
def get_default_llm_provider_enum(self) -> LLMType:
|
||||
"""Get first valid LLM provider enum"""
|
||||
mappings = {
|
||||
LLMType.OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
|
||||
),
|
||||
LLMType.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
|
||||
LLMType.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
|
||||
LLMType.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
|
||||
LLMType.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
|
||||
LLMType.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
|
||||
LLMType.METAGPT: bool(self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"),
|
||||
LLMType.AZURE: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY)
|
||||
and self.OPENAI_API_TYPE == "azure"
|
||||
and self.DEPLOYMENT_NAME
|
||||
and self.OPENAI_API_VERSION
|
||||
),
|
||||
LLMType.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
|
||||
}
|
||||
provider = None
|
||||
for k, v in mappings.items():
|
||||
if v:
|
||||
provider = k
|
||||
break
|
||||
if provider is None:
|
||||
if self.DEFAULT_PROVIDER:
|
||||
provider = LLMType(self.DEFAULT_PROVIDER)
|
||||
else:
|
||||
raise NotConfiguredException("You should config a LLM configuration first")
|
||||
|
||||
if provider is LLMType.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
model_name = self.get_model_name(provider=provider)
|
||||
if model_name:
|
||||
logger.info(f"{provider} Model: {model_name}")
|
||||
if provider:
|
||||
logger.info(f"API: {provider}")
|
||||
return provider
|
||||
|
||||
def get_model_name(self, provider=None) -> str:
|
||||
provider = provider or self.get_default_llm_provider_enum()
|
||||
model_mappings = {
|
||||
LLMType.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMType.AZURE: self.DEPLOYMENT_NAME,
|
||||
}
|
||||
return model_mappings.get(provider, "")
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_llm_key(k: str) -> bool:
|
||||
return bool(k and k != "YOUR_API_KEY")
|
||||
|
||||
def _update(self):
|
||||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
|
||||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
|
||||
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
|
||||
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
|
||||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
self.gemini_api_key = self._get("GEMINI_API_KEY")
|
||||
self.ollama_api_base = self._get("OLLAMA_API_BASE")
|
||||
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
|
||||
|
||||
# if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
|
||||
# _ = self.get_default_llm_provider_enum()
|
||||
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
self.openai_api_type = self._get("OPENAI_API_TYPE")
|
||||
self.openai_api_version = self._get("OPENAI_API_VERSION")
|
||||
self.openai_api_rpm = self._get("RPM", 3)
|
||||
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4-1106-preview")
|
||||
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
|
||||
self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4")
|
||||
|
||||
self.spark_appid = self._get("SPARK_APPID")
|
||||
self.spark_api_secret = self._get("SPARK_API_SECRET")
|
||||
self.spark_api_key = self._get("SPARK_API_KEY")
|
||||
self.domain = self._get("DOMAIN")
|
||||
self.spark_url = self._get("SPARK_URL")
|
||||
|
||||
self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
|
||||
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
|
||||
|
||||
self.claude_api_key = self._get("ANTHROPIC_API_KEY")
|
||||
|
||||
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
|
||||
self.serper_api_key = self._get("SERPER_API_KEY")
|
||||
self.google_api_key = self._get("GOOGLE_API_KEY")
|
||||
self.google_cse_id = self._get("GOOGLE_CSE_ID")
|
||||
self.search_engine = SearchEngineType(self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE))
|
||||
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", WebBrowserEngineType.PLAYWRIGHT))
|
||||
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
|
||||
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
|
||||
|
||||
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
|
||||
if self.long_term_memory:
|
||||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.code_review_k_times = 2
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
self.mmdc = self._get("MMDC", "mmdc")
|
||||
self.calc_usage = self._get("CALC_USAGE", True)
|
||||
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
|
||||
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
|
||||
self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
|
||||
self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")
|
||||
|
||||
workspace_uid = (
|
||||
self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
|
||||
)
|
||||
self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
|
||||
self.prompt_schema = self._get("PROMPT_FORMAT", "json")
|
||||
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
|
||||
val = self._get("WORKSPACE_PATH_WITH_UID")
|
||||
if val and val.lower() == "true": # for agent
|
||||
self.workspace_path = self.workspace_path / workspace_uid
|
||||
self._ensure_workspace_exists()
|
||||
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
|
||||
self.timeout = int(self._get("TIMEOUT", 60))
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
||||
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
|
||||
if project_path:
|
||||
inc = True
|
||||
project_name = project_name or Path(project_path).name
|
||||
self.project_path = project_path
|
||||
self.project_name = project_name
|
||||
self.inc = inc
|
||||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
def _ensure_workspace_exists(self):
|
||||
self.workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}")
|
||||
|
||||
def _init_with_config_files_and_env(self, yaml_file):
|
||||
"""Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority"""
|
||||
configs = dict(os.environ)
|
||||
|
||||
for _yaml_file in [yaml_file, self.key_yaml_file, self.home_yaml_file]:
|
||||
if not _yaml_file.exists():
|
||||
continue
|
||||
|
||||
# Load local YAML file
|
||||
with open(_yaml_file, "r", encoding="utf-8") as file:
|
||||
yaml_data = yaml.safe_load(file)
|
||||
if not yaml_data:
|
||||
continue
|
||||
configs.update(yaml_data)
|
||||
OPTIONS.set(configs)
|
||||
|
||||
@staticmethod
|
||||
def _get(*args, **kwargs):
|
||||
i = OPTIONS.get()
|
||||
return i.get(*args, **kwargs)
|
||||
|
||||
def get(self, key, *args, **kwargs):
|
||||
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
|
||||
Throw an error if not found."""
|
||||
value = self._get(key, *args, **kwargs)
|
||||
if value is None:
|
||||
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
|
||||
return value
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
OPTIONS.get()[name] = value
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
i = OPTIONS.get()
|
||||
return i.get(name)
|
||||
|
||||
def set_context(self, options: dict):
|
||||
"""Update current config"""
|
||||
if not options:
|
||||
return
|
||||
opts = deepcopy(OPTIONS.get())
|
||||
opts.update(options)
|
||||
OPTIONS.set(opts)
|
||||
self._update()
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
"""Return all key-values"""
|
||||
return OPTIONS.get()
|
||||
|
||||
def new_environ(self):
|
||||
"""Return a new os.environ object"""
|
||||
env = os.environ.copy()
|
||||
i = self.options
|
||||
env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
|
||||
CONFIG = Config()
|
||||
|
|
@ -9,7 +9,7 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
|
|
@ -44,15 +44,15 @@ class Config(CLIParams, YamlModel):
|
|||
"""Configurations for MetaGPT"""
|
||||
|
||||
# Key Parameters
|
||||
llm: Dict[str, LLMConfig] = Field(default_factory=Dict)
|
||||
llm: LLMConfig
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
# Tool Parameters
|
||||
search: Dict[str, SearchConfig] = {}
|
||||
browser: Dict[str, BrowserConfig] = {"default": BrowserConfig()}
|
||||
mermaid: Dict[str, MermaidConfig] = {"default": MermaidConfig()}
|
||||
search: Optional[SearchConfig] = None
|
||||
browser: BrowserConfig = BrowserConfig()
|
||||
mermaid: MermaidConfig = MermaidConfig()
|
||||
|
||||
# Storage Parameters
|
||||
s3: Optional[S3Config] = None
|
||||
|
|
@ -110,46 +110,17 @@ class Config(CLIParams, YamlModel):
|
|||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
def _get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
|
||||
"""Get LLM instance by name"""
|
||||
if name is None:
|
||||
# Use the first LLM as default
|
||||
name = list(self.llm.keys())[0]
|
||||
if name not in self.llm:
|
||||
raise ValueError(f"LLM {name} not found in config")
|
||||
return self.llm[name]
|
||||
|
||||
def get_llm_configs_by_type(self, llm_type: LLMType) -> List[LLMConfig]:
|
||||
"""Get LLM instance by type"""
|
||||
return [v for k, v in self.llm.items() if v.api_type == llm_type]
|
||||
|
||||
def get_llm_config_by_type(self, llm_type: LLMType) -> Optional[LLMConfig]:
|
||||
"""Get LLM instance by type"""
|
||||
llm = self.get_llm_configs_by_type(llm_type)
|
||||
if llm:
|
||||
return llm[0]
|
||||
return None
|
||||
|
||||
def get_llm_config(self, name: Optional[str] = None, provider: LLMType = None) -> LLMConfig:
|
||||
"""Return a LLMConfig instance"""
|
||||
if provider:
|
||||
llm_configs = self.get_llm_configs_by_type(provider)
|
||||
|
||||
if len(llm_configs) == 0:
|
||||
raise ValueError(f"Cannot find llm config with name {name} and provider {provider}")
|
||||
# return the first one if name is None, or return the only one
|
||||
llm_config = llm_configs[0]
|
||||
else:
|
||||
llm_config = self._get_llm_config(name)
|
||||
return llm_config
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
return self.get_llm_config_by_type(LLMType.OPENAI)
|
||||
if self.llm.api_type == LLMType.OPENAI:
|
||||
return self.llm
|
||||
return None
|
||||
|
||||
def get_azure_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get Azure LLMConfig by name. If no Azure, raise Exception"""
|
||||
return self.get_llm_config_by_type(LLMType.AZURE)
|
||||
if self.llm.api_type == LLMType.AZURE:
|
||||
return self.llm
|
||||
return None
|
||||
|
||||
|
||||
def merge_dict(dicts: Iterable[Dict]) -> Dict:
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class LLMConfig(YamlModel):
|
|||
api_type: LLMType = LLMType.OPENAI
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import OPTIONS
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
|
|
@ -77,10 +77,10 @@ class Context(BaseModel):
|
|||
# self._llm = None
|
||||
# return self._llm
|
||||
|
||||
def llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
|
||||
def llm(self) -> BaseLLM:
|
||||
"""Return a LLM instance, fixme: support cache"""
|
||||
# if self._llm is None:
|
||||
self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
|
||||
self._llm = create_llm_instance(self.config.llm)
|
||||
if self._llm.cost_manager is None:
|
||||
self._llm.cost_manager = self.cost_manager
|
||||
return self._llm
|
||||
|
|
@ -103,7 +103,6 @@ class ContextMixin(BaseModel):
|
|||
_config: Optional[Config] = None
|
||||
|
||||
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
|
||||
_llm_config: Optional[LLMConfig] = None
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
||||
def __init__(
|
||||
|
|
@ -132,20 +131,10 @@ class ContextMixin(BaseModel):
|
|||
"""Set config"""
|
||||
self.set("_config", config, override)
|
||||
|
||||
def set_llm_config(self, llm_config: LLMConfig, override=False):
|
||||
"""Set llm config"""
|
||||
self.set("_llm_config", llm_config, override)
|
||||
|
||||
def set_llm(self, llm: BaseLLM, override=False):
|
||||
"""Set llm"""
|
||||
self.set("_llm", llm, override)
|
||||
|
||||
def use_llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
|
||||
"""Use a LLM instance"""
|
||||
self._llm_config = self.config.get_llm_config(name, provider)
|
||||
self._llm = None
|
||||
return self.llm
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
"""Role config: role config > context config"""
|
||||
|
|
@ -172,11 +161,11 @@ class ContextMixin(BaseModel):
|
|||
|
||||
@property
|
||||
def llm(self) -> BaseLLM:
|
||||
"""Role llm: role llm > context llm"""
|
||||
"""Role llm: if not existed, init from role.config"""
|
||||
# print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}")
|
||||
if self._llm_config and not self._llm:
|
||||
self._llm = self.context.llm_with_cost_manager_from_llm_config(self._llm_config)
|
||||
return self._llm or self.context.llm()
|
||||
if not self._llm:
|
||||
self._llm = self.context.llm_with_cost_manager_from_llm_config(self.config.llm)
|
||||
return self._llm
|
||||
|
||||
@llm.setter
|
||||
def llm(self, llm: BaseLLM) -> None:
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config import Config
|
||||
|
||||
|
||||
class BaseStore(ABC):
|
||||
"""FIXME: consider add_index, set_index and think about granularity."""
|
||||
|
|
|
|||
|
|
@ -6,14 +6,12 @@
|
|||
@File : llm.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
def LLM(name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
def LLM() -> BaseLLM:
|
||||
"""get the default llm provider if name is None"""
|
||||
# context.use_llm(name=name, provider=provider)
|
||||
return CONTEXT.llm(name=name, provider=provider)
|
||||
return CONTEXT.llm()
|
||||
|
|
|
|||
|
|
@ -162,6 +162,7 @@ class Engineer(Role):
|
|||
if not is_pass:
|
||||
todo.i_context.reason = reason
|
||||
tasks.append(todo.i_context.dict())
|
||||
|
||||
await self.project_repo.docs.code_summary.save(
|
||||
filename=Path(todo.i_context.design_filename).name,
|
||||
content=todo.i_context.model_dump_json(),
|
||||
|
|
|
|||
|
|
@ -62,7 +62,9 @@ class QaEngineer(Role):
|
|||
logger.info(f"Writing {test_doc.filename}..")
|
||||
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
|
||||
context = await WriteTest(i_context=context, context=self.context, llm=self.llm).run()
|
||||
await self.project_repo.tests.save_doc(doc=test_doc, dependencies={context.code_doc.root_relative_path})
|
||||
await self.project_repo.tests.save_doc(
|
||||
doc=context.test_doc, dependencies={context.code_doc.root_relative_path}
|
||||
)
|
||||
|
||||
# prepare context for run tests in next round
|
||||
run_code_context = RunCodeContext(
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@field_validator("google_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_google_api_key(cls, val: str):
|
||||
val = val or config.search["google"].api_key
|
||||
val = val or config.search.api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
|
||||
|
|
@ -47,7 +47,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@field_validator("google_cse_id", mode="before")
|
||||
@classmethod
|
||||
def check_google_cse_id(cls, val: str):
|
||||
val = val or config.search["google"].cse_id
|
||||
val = val or config.search.cse_id
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
@field_validator("serpapi_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serpapi_api_key(cls, val: str):
|
||||
val = val or config.search["serpapi"].api_key
|
||||
val = val or config.search.api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class SerperWrapper(BaseModel):
|
|||
@field_validator("serper_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serper_api_key(cls, val: str):
|
||||
val = val or config.search["serper"].api_key
|
||||
val = val or config.search.api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
|
||||
|
|
|
|||
|
|
@ -282,6 +282,6 @@ class UTGenerator:
|
|||
"""Choose based on different calling methods"""
|
||||
result = ""
|
||||
if self.chatgpt_method == "API":
|
||||
result = await GPTAPI(config.get_llm_config()).aask_code(messages=messages)
|
||||
result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -28,12 +28,10 @@ class PlaywrightWrapper:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
browser_type: Literal["chromium", "firefox", "webkit"] | None = None,
|
||||
browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
|
||||
launch_kwargs: dict | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if browser_type is None:
|
||||
browser_type = config.browser["playwright"].driver
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if config.proxy and "proxy" not in launch_kwargs:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue