diff --git a/config/config2.yaml b/config/config2.yaml new file mode 100644 index 000000000..0040023a8 --- /dev/null +++ b/config/config2.yaml @@ -0,0 +1,4 @@ +llm: + gpt3t: + api_key: "YOUR_API_KEY" + model: "gpt-3.5-turbo-1106" \ No newline at end of file diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 340dfafa4..e908fe6ee 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -6,7 +6,7 @@ Author: garylin2099 import re from metagpt.actions import Action -from metagpt.config import CONFIG +from metagpt.config2 import config from metagpt.const import METAGPT_ROOT from metagpt.logs import logger from metagpt.roles import Role @@ -48,8 +48,8 @@ class CreateAgent(Action): pattern = r"```python(.*)```" match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" - CONFIG.workspace_path.mkdir(parents=True, exist_ok=True) - new_file = CONFIG.workspace_path / "agent_created_agent.py" + config.workspace.path.mkdir(parents=True, exist_ok=True) + new_file = config.workspace.path / "agent_created_agent.py" new_file.write_text(code_text) return code_text diff --git a/examples/search_kb.py b/examples/search_kb.py index 0e0e0ffd0..995720cc1 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -8,7 +8,7 @@ import asyncio from langchain.embeddings import OpenAIEmbeddings -from metagpt.config import CONFIG +from metagpt.config2 import config from metagpt.const import DATA_PATH, EXAMPLE_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger @@ -16,7 +16,8 @@ from metagpt.roles import Sales def get_store(): - embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url) + llm = config.get_openai_llm() + embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url) return FaissStore(DATA_PATH / "example.json", embedding=embedding) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index b586bcc22..ec80a96dd 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -13,6 +13,7 @@ from typing import Optional, Union from pydantic import ConfigDict, Field, model_validator from metagpt.actions.action_node import ActionNode +from metagpt.context import Context from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from metagpt.schema import ( @@ -33,14 +34,41 @@ class Action(SerializationMixin, is_polymorphic_base=True): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) + _context: Optional[Context] = Field(default=None, exclude=True) + + @property + def git_repo(self): + return self._context.git_repo + + @property + def src_workspace(self): + return self._context.src_workspace + + @property + def prompt_schema(self): + return self._context.config.prompt_schema + + @property + def project_name(self): + return self._context.config.project_name + + @project_name.setter + def project_name(self, value): + self._context.config.project_name = value + + @property + def project_path(self): + return self._context.config.project_path @model_validator(mode="before") + @classmethod def set_name_if_empty(cls, values): if "name" not in values or not values["name"]: values["name"] = cls.__name__ return values @model_validator(mode="before") + @classmethod def _init_with_instruction(cls, values): if "instruction" in values: name = values["name"] diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6c65b33ef..16a43ea69 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type from pydantic import BaseModel, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential -from metagpt.config import CONFIG from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess @@ -262,7 +261,7 @@ class ActionNode: output_data_mapping: dict, system_msgs: Optional[list[str]] = None, schema="markdown", # compatible to original format - timeout=CONFIG.timeout, + timeout=None, ) -> (str, BaseModel): """Use ActionOutput to wrap the output of aask""" content = await self.llm.aask(prompt, system_msgs, timeout=timeout) @@ -294,7 +293,7 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): + async def simple_fill(self, schema, mode, timeout=None, exclude=None): prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": @@ -309,7 +308,7 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=None, exclude=[]): """Fill the node(s) with mode. :param context: Everything we should know when filling node. diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 34f784072..2916005c2 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -9,12 +9,13 @@ 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. """ import re +from typing import Optional from pydantic import Field from metagpt.actions.action import Action -from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO +from metagpt.context import Context from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser @@ -49,8 +50,8 @@ Now you should start rewriting the code: class DebugError(Action): - name: str = "DebugError" context: RunCodeContext = Field(default_factory=RunCodeContext) + _context: Optional[Context] = None async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( @@ -66,7 +67,7 @@ class DebugError(Action): logger.info(f"Debug and rewrite {self.context.test_filename}") code_doc = await FileRepository.get_file( - filename=self.context.code_filename, relative_path=CONFIG.src_workspace + filename=self.context.code_filename, relative_path=self._context.src_workspace ) if not code_doc: return "" diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 2574550e4..664c1c5c3 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -15,7 +15,6 @@ from typing import Optional from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE -from metagpt.config import CONFIG from metagpt.const import ( DATA_API_DESIGN_FILE_REPO, PRDS_FILE_REPO, @@ -46,13 +45,13 @@ class WriteDesign(Action): "clearly and in detail." ) - async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema): + async def run(self, with_messages: Message, schema: str = None): # Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory. - prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) + prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files # Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone # changes. - system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) + system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files # For those PRDs and design documents that have undergone changes, regenerate the design content. @@ -76,11 +75,11 @@ class WriteDesign(Action): # leaving room for global optimization in subsequent steps. return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files) - async def _new_system_design(self, context, schema=CONFIG.prompt_schema): + async def _new_system_design(self, context, schema=None): node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) return node - async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): + async def _merge(self, prd_doc, system_design_doc, schema=None): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) system_design_doc.content = node.instruct_content.model_dump_json() @@ -106,23 +105,21 @@ class WriteDesign(Action): await self._save_pdf(doc) return doc - @staticmethod - async def _save_data_api_design(design_doc): + async def _save_data_api_design(self, design_doc): m = json.loads(design_doc.content) data_api_design = m.get("Data structures and interfaces") if not data_api_design: return - pathname = CONFIG.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("") + pathname = self.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("") await WriteDesign._save_mermaid_file(data_api_design, pathname) logger.info(f"Save class view to {str(pathname)}") - @staticmethod - async def _save_seq_flow(design_doc): + async def _save_seq_flow(self, design_doc): m = json.loads(design_doc.content) seq_flow = m.get("Program call flow") if not seq_flow: return - pathname = CONFIG.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("") + pathname = self.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("") await WriteDesign._save_mermaid_file(seq_flow, pathname) logger.info(f"Saving sequence flow to {str(pathname)}") diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index a936ea655..3bd362207 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -12,7 +12,6 @@ from pathlib import Path from typing import Optional from metagpt.actions import Action, ActionOutput -from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository @@ -25,18 +24,22 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" context: Optional[str] = None + @property + def config(self): + return self._context.config + def _init_repo(self): """Initialize the Git environment.""" - if not CONFIG.project_path: - name = CONFIG.project_name or FileRepository.new_filename() - path = Path(CONFIG.workspace_path) / name + if not self.config.project_path: + name = self.config.project_name or FileRepository.new_filename() + path = Path(self.config.workspace.path) / name else: - path = Path(CONFIG.project_path) - if path.exists() and not CONFIG.inc: + path = Path(self.config.project_path) + if path.exists() and not self.config.inc: shutil.rmtree(path) - CONFIG.project_path = path - CONFIG.project_name = path.name - CONFIG.git_repo = GitRepository(local_path=path, auto_init=True) + self.config.project_path = path + self.config.project_name = path.name + self._context.git_repo = GitRepository(local_path=path, auto_init=True) async def run(self, with_messages, **kwargs): """Create and initialize the workspace folder, initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index e40c2034b..f8ccd922a 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -16,7 +16,6 @@ from typing import Optional from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE -from metagpt.config import CONFIG from metagpt.const import ( PACKAGE_REQUIREMENTS_FILENAME, SYSTEM_DESIGN_FILE_REPO, @@ -40,11 +39,15 @@ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - async def run(self, with_messages, schema=CONFIG.prompt_schema): - system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) + @property + def prompt_schema(self): + return self._context.config.prompt_schema + + async def run(self, with_messages, schema=None): + system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files - tasks_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO) + tasks_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO) changed_tasks = tasks_file_repo.changed_files change_files = Documents() # Rewrite the system designs that have undergone changes based on the git head diff under @@ -87,21 +90,20 @@ class WriteTasks(Action): await self._save_pdf(task_doc=task_doc) return task_doc - async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema): - node = await PM_NODE.fill(context, self.llm, schema) + async def _run_new_tasks(self, context): + node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema) return node - async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document: + async def _merge(self, system_design_doc, task_doc) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content) - node = await PM_NODE.fill(context, self.llm, schema) + node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema) task_doc.content = node.instruct_content.model_dump_json() return task_doc - @staticmethod - async def _update_requirements(doc): + async def _update_requirements(self, doc): m = json.loads(doc.content) packages = set(m.get("Required Python third-party packages", set())) - file_repo = CONFIG.git_repo.new_file_repository() + file_repo = self.git_repo.new_file_repository() requirement_doc = await file_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME) if not requirement_doc: requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="") diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index 66bc2c7ab..773e40a3e 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -10,7 +10,6 @@ import re from pathlib import Path from metagpt.actions import Action -from metagpt.config import CONFIG from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO from metagpt.repo_parser import RepoParser from metagpt.utils.di_graph_repository import DiGraphRepository @@ -21,8 +20,8 @@ class RebuildClassView(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name=name, context=context, llm=llm) - async def run(self, with_messages=None, format=CONFIG.prompt_schema): - graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name + async def run(self, with_messages=None): + graph_repo_pathname = self.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.git_repo.workdir.name graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) repo_parser = RepoParser(base_directory=self.context) class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint @@ -57,7 +56,7 @@ class RebuildClassView(Action): # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}") async def _save(self, graph_db): - class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) + class_view_file_repo = self.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW) all_class_view = [] for spo in dataset: diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 30b06f1a6..74ad36dae 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -21,7 +21,6 @@ from typing import Tuple from pydantic import Field from metagpt.actions.action import Action -from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception @@ -89,13 +88,12 @@ class RunCode(Action): return "", str(e) return namespace.get("result", ""), "" - @classmethod - async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]: + async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]: working_directory = str(working_directory) additional_python_paths = [str(path) for path in additional_python_paths] # Copy the current environment variables - env = CONFIG.new_environ() + env = self._context.new_environ() # Modify the PYTHONPATH environment variable additional_python_paths = [working_directory] + additional_python_paths diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index d2e361f73..39ca23df5 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -11,7 +11,7 @@ import pydantic from pydantic import Field, model_validator from metagpt.actions import Action -from metagpt.config import CONFIG, Config +from metagpt.config import Config from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools import SearchEngineType @@ -103,12 +103,11 @@ You are a member of a professional butler team and will provide helpful suggesti """ -# TOTEST class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None config: None = Field(default_factory=Config) - engine: Optional[SearchEngineType] = CONFIG.search_engine + engine: Optional[SearchEngineType] = None search_func: Optional[Any] = None search_engine: SearchEngine = None result: str = "" diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index bdad546d7..94f3c6541 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -11,7 +11,6 @@ from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action -from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext @@ -105,7 +104,7 @@ class SummarizeCode(Action): design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO) task_pathname = Path(self.context.task_filename) task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO) - src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace) + src_file_repo = self.git_repo.new_file_repository(relative_path=self._context.src_workspace) code_blocks = [] for filename in self.context.codes_filenames: code_doc = await src_file_repo.get(filename) diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 25c4912c3..5b09aa2b0 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -21,7 +21,6 @@ from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action -from metagpt.config import CONFIG from metagpt.const import ( BUGFIX_FILENAME, CODE_SUMMARIES_FILE_REPO, @@ -114,7 +113,12 @@ class WriteCode(Action): if bug_feedback: code_context = coding_context.code_doc.content else: - code_context = await self.get_codes(coding_context.task_doc, exclude=self.context.filename) + code_context = await self.get_codes( + coding_context.task_doc, + exclude=self.context.filename, + git_repo=self.git_repo, + src_workspace=self._context.src_workspace, + ) prompt = PROMPT_TEMPLATE.format( design=coding_context.design_doc.content if coding_context.design_doc else "", @@ -129,13 +133,13 @@ class WriteCode(Action): code = await self.write_code(prompt) if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = CONFIG.src_workspace if CONFIG.src_workspace else "" + root_path = self._context.src_workspace if self._context.src_workspace else "" coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path) coding_context.code_doc.content = code return coding_context @staticmethod - async def get_codes(task_doc, exclude) -> str: + async def get_codes(task_doc, exclude, git_repo, src_workspace) -> str: if not task_doc: return "" if not task_doc.content: @@ -143,7 +147,7 @@ class WriteCode(Action): m = json.loads(task_doc.content) code_filenames = m.get("Task list", []) codes = [] - src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace) + src_file_repo = git_repo.new_file_repository(relative_path=src_workspace) for filename in code_filenames: if filename == exclude: continue diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index a8c913573..e261f0623 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action -from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -137,11 +136,16 @@ class WriteCodeReview(Action): async def run(self, *args, **kwargs) -> CodingContext: iterative_code = self.context.code_doc.content - k = CONFIG.code_review_k_times or 1 + k = self._context.config.code_review_k_times or 1 for i in range(k): format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename) task_content = self.context.task_doc.content if self.context.task_doc else "" - code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename) + code_context = await WriteCode.get_codes( + self.context.task_doc, + exclude=self.context.filename, + git_repo=self._context.git_repo, + src_workspace=self.src_workspace, + ) context = "\n".join( [ "## System Design\n" + str(self.context.design_doc) + "\n", diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index d51c0a7be..e77a469c1 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -26,7 +26,6 @@ from metagpt.actions.write_prd_an import ( WP_ISSUE_TYPE_NODE, WRITE_PRD_NODE, ) -from metagpt.config import CONFIG from metagpt.const import ( BUGFIX_FILENAME, COMPETITIVE_ANALYSIS_FILE_REPO, @@ -65,10 +64,10 @@ class WritePRD(Action): name: str = "WritePRD" content: Optional[str] = None - async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: + async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are # related to the PRD. If they are related, rewrite the PRD. - docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) + docs_file_repo = self.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME) if requirement_doc and await self._is_bugfix(requirement_doc.content): await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content) @@ -85,7 +84,7 @@ class WritePRD(Action): else: await docs_file_repo.delete(filename=BUGFIX_FILENAME) - prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) + prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO) prd_docs = await prds_file_repo.get_all() change_files = Documents() for prd_doc in prd_docs: @@ -109,7 +108,7 @@ class WritePRD(Action): # optimization in subsequent steps. return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files) - async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput: + async def _run_new_requirement(self, requirements) -> ActionOutput: # sas = SearchAndSummarize() # # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) # rsp = "" @@ -117,7 +116,7 @@ class WritePRD(Action): # if sas.result: # logger.info(sas.result) # logger.info(rsp) - project_name = CONFIG.project_name if CONFIG.project_name else "" + project_name = self.project_name context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) exclude = [PROJECT_NAME.key] if project_name else [] node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema @@ -129,11 +128,11 @@ class WritePRD(Action): node = await WP_IS_RELATIVE_NODE.fill(context, self.llm) return node.get("is_relative") == "YES" - async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document: - if not CONFIG.project_name: - CONFIG.project_name = Path(CONFIG.project_path).name + async def _merge(self, new_requirement_doc, prd_doc) -> Document: + if not self.project_name: + self.project_name = Path(self.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content) - node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema) + node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema) prd_doc.content = node.instruct_content.model_dump_json() await self._rename_workspace(node) return prd_doc @@ -157,15 +156,12 @@ class WritePRD(Action): await self._save_pdf(new_prd_doc) return new_prd_doc - @staticmethod - async def _save_competitive_analysis(prd_doc): + async def _save_competitive_analysis(self, prd_doc): m = json.loads(prd_doc.content) quadrant_chart = m.get("Competitive Quadrant Chart") if not quadrant_chart: return - pathname = ( - CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("") - ) + pathname = self.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("") if not pathname.parent.exists(): pathname.parent.mkdir(parents=True, exist_ok=True) await mermaid_to_file(quadrant_chart, pathname) @@ -174,20 +170,19 @@ class WritePRD(Action): async def _save_pdf(prd_doc): await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO) - @staticmethod - async def _rename_workspace(prd): - if not CONFIG.project_name: + async def _rename_workspace(self, prd): + if not self.project_name: if isinstance(prd, (ActionOutput, ActionNode)): ws_name = prd.instruct_content.model_dump()["Project Name"] else: ws_name = CodeParser.parse_str(block="Project Name", text=prd) if ws_name: - CONFIG.project_name = ws_name - CONFIG.git_repo.rename_root(CONFIG.project_name) + self.project_name = ws_name + self.git_repo.rename_root(self.project_name) async def _is_bugfix(self, context) -> bool: - src_workspace_path = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name - code_files = CONFIG.git_repo.get_files(relative_path=src_workspace_path) + src_workspace_path = self.git_repo.workdir / self.git_repo.workdir.name + code_files = self.git_repo.get_files(relative_path=src_workspace_path) if not code_files: return False node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm) diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index b824e055e..ea9be4819 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -75,6 +75,7 @@ class WriteTeachingPlanPart(Action): if "{" not in value: return value + # FIXME: 从Context中获取参数 merged_opts = CONFIG.options or {} try: return value.format(**merged_opts) diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 0166f5417..2b98e7458 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -11,7 +11,6 @@ from typing import Optional from metagpt.actions.action import Action -from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO from metagpt.logs import logger from metagpt.schema import Document, TestingContext @@ -64,7 +63,7 @@ class WriteTest(Action): code_to_test=self.context.code_doc.content, test_file_name=self.context.test_doc.filename, source_file_path=self.context.code_doc.root_relative_path, - workspace=CONFIG.git_repo.workdir, + workspace=self.git_repo.workdir, ) self.context.test_doc.content = await self.write_code(prompt) return self.context diff --git a/metagpt/config.py b/metagpt/config.py index eb3636c9a..176b54cfc 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -11,13 +11,13 @@ import json import os import warnings from copy import deepcopy -from enum import Enum from pathlib import Path from typing import Any from uuid import uuid4 import yaml +from metagpt.configs.llm_config import LLMType from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType @@ -38,19 +38,6 @@ class NotConfiguredException(Exception): super().__init__(self.message) -class LLMProviderEnum(Enum): - OPENAI = "openai" - ANTHROPIC = "anthropic" - SPARK = "spark" - ZHIPUAI = "zhipuai" - FIREWORKS = "fireworks" - OPEN_LLM = "open_llm" - GEMINI = "gemini" - METAGPT = "metagpt" - AZURE_OPENAI = "azure_openai" - OLLAMA = "ollama" - - class Config(metaclass=Singleton): """ Regular usage method: @@ -81,27 +68,25 @@ class Config(metaclass=Singleton): global_options.update(OPTIONS.get()) logger.debug("Config loading done.") - def get_default_llm_provider_enum(self) -> LLMProviderEnum: + def get_default_llm_provider_enum(self) -> LLMType: """Get first valid LLM provider enum""" mappings = { - LLMProviderEnum.OPENAI: bool( + LLMType.OPENAI: bool( self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL ), - LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY), - LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY), - LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY), - LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE), - LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY), - LLMProviderEnum.METAGPT: bool( - self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt" - ), - LLMProviderEnum.AZURE_OPENAI: bool( + LLMType.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY), + LLMType.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY), + LLMType.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY), + LLMType.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE), + LLMType.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY), + LLMType.METAGPT: bool(self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"), + LLMType.AZURE_OPENAI: bool( self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "azure" and self.DEPLOYMENT_NAME and self.OPENAI_API_VERSION ), - LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE), + LLMType.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE), } provider = None for k, v in mappings.items(): @@ -109,7 +94,7 @@ class Config(metaclass=Singleton): provider = k break - if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): + if provider is LLMType.GEMINI and not require_python_version(req_version=(3, 10)): warnings.warn("Use Gemini requires Python >= 3.10") model_name = self.get_model_name(provider=provider) if model_name: @@ -122,8 +107,8 @@ class Config(metaclass=Singleton): def get_model_name(self, provider=None) -> str: provider = provider or self.get_default_llm_provider_enum() model_mappings = { - LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL, - LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME, + LLMType.OPENAI: self.OPENAI_API_MODEL, + LLMType.AZURE_OPENAI: self.DEPLOYMENT_NAME, } return model_mappings.get(provider, "") @@ -166,6 +151,7 @@ class Config(metaclass=Singleton): self.fireworks_api_model = self._get("FIREWORKS_API_MODEL") self.claude_api_key = self._get("ANTHROPIC_API_KEY") + self.serpapi_api_key = self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") self.google_api_key = self._get("GOOGLE_API_KEY") @@ -200,7 +186,7 @@ class Config(metaclass=Singleton): self.workspace_path = self.workspace_path / workspace_uid self._ensure_workspace_exists() self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1) - self.timeout = int(self._get("TIMEOUT", 3)) + self.timeout = int(self._get("TIMEOUT", 60)) def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): """update config via cli""" diff --git a/metagpt/config2.py b/metagpt/config2.py new file mode 100644 index 000000000..ca46cc7a5 --- /dev/null +++ b/metagpt/config2.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 01:25 +@Author : alexanderwu +@File : llm_factory.py +""" +import os +from pathlib import Path +from typing import Dict, Iterable, List, Literal, Optional + +from pydantic import BaseModel, Field, model_validator + +from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.configs.mermaid_config import MermaidConfig +from metagpt.configs.redis_config import RedisConfig +from metagpt.configs.s3_config import S3Config +from metagpt.configs.search_config import SearchConfig +from metagpt.configs.workspace_config import WorkspaceConfig +from metagpt.const import METAGPT_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class CLIParams(BaseModel): + project_path: str = "" + project_name: str = "" + inc: bool = False + reqa_file: str = "" + max_auto_summarize_code: int = 0 + git_reinit: bool = False + + @model_validator(mode="after") + def check_project_path(self): + if self.project_path: + self.inc = True + self.project_name = self.project_name or Path(self.project_path).name + + +class Config(CLIParams, YamlModel): + # Key Parameters + llm: Dict[str, LLMConfig] = Field(default_factory=Dict) + + # Global Proxy. Will be used if llm.proxy is not set + proxy: str = "" + + # Tool Parameters + search: Dict[str, SearchConfig] = {} + browser: Dict[str, BrowserConfig] = {"default": BrowserConfig()} + mermaid: Dict[str, MermaidConfig] = {"default": MermaidConfig()} + + # Storage Parameters + s3: Optional[S3Config] = None + redis: Optional[RedisConfig] = None + + # Misc Parameters + repair_llm_output: bool = False + prompt_schema: Literal["json", "markdown", "raw"] = "json" + workspace: WorkspaceConfig = WorkspaceConfig() + enable_longterm_memory: bool = False + code_review_k_times: int = 2 + + # Will be removed in the future + llm_for_researcher_summary: str = "gpt3" + llm_for_researcher_report: str = "gpt3" + METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = "" + + @classmethod + def default(cls): + """Load default config + - Priority: env < default_config_paths + - Inside default_config_paths, the latter one overwrites the former one + """ + default_config_paths: List[Path] = [ + METAGPT_ROOT / "config/config2.yaml", + Path.home() / ".metagpt/config2.yaml", + ] + + dicts = [dict(os.environ)] + dicts += [Config.read_yaml(path) for path in default_config_paths] + final = merge_dict(dicts) + return Config(**final) + + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): + """update config via cli""" + + # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. + if project_path: + inc = True + project_name = project_name or Path(project_path).name + self.project_path = project_path + self.project_name = project_name + self.inc = inc + self.reqa_file = reqa_file + self.max_auto_summarize_code = max_auto_summarize_code + + def get_llm_config(self, name: Optional[str] = None) -> LLMConfig: + """Get LLM instance by name""" + if name is None: + # Use the first LLM as default + name = list(self.llm.keys())[0] + if name not in self.llm: + raise ValueError(f"LLM {name} not found in config") + return self.llm[name] + + def get_openai_llm(self, name: Optional[str] = None) -> LLMConfig: + """Get OpenAI LLMConfig by name. If no OpenAI, raise Exception""" + if name is None: + # Use the first OpenAI LLM as default + name = [k for k, v in self.llm.items() if v.api_type == LLMType.OPENAI][0] + if name not in self.llm: + raise ValueError(f"OpenAI LLM {name} not found in config") + return self.llm[name] + + +def merge_dict(dicts: Iterable[Dict]) -> Dict: + """Merge multiple dicts into one, with the latter dict overwriting the former""" + result = {} + for dictionary in dicts: + result.update(dictionary) + return result + + +config = Config.default() diff --git a/metagpt/configs/__init__.py b/metagpt/configs/__init__.py new file mode 100644 index 000000000..e42e6788f --- /dev/null +++ b/metagpt/configs/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 16:33 +@Author : alexanderwu +@File : __init__.py +""" diff --git a/metagpt/configs/browser_config.py b/metagpt/configs/browser_config.py new file mode 100644 index 000000000..00f918735 --- /dev/null +++ b/metagpt/configs/browser_config.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:06 +@Author : alexanderwu +@File : browser_config.py +""" +from typing import Literal + +from metagpt.tools import WebBrowserEngineType +from metagpt.utils.yaml_model import YamlModel + + +class BrowserConfig(YamlModel): + """Config for Browser""" + + engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT + browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome" + driver: Literal["chromium", "firefox", "webkit"] = "chromium" + path: str = "" diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py new file mode 100644 index 000000000..0961478a4 --- /dev/null +++ b/metagpt/configs/llm_config.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 16:33 +@Author : alexanderwu +@File : llm_config.py +""" +from enum import Enum +from typing import Optional + +from pydantic import field_validator + +from metagpt.utils.yaml_model import YamlModel + + +class LLMType(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + SPARK = "spark" + ZHIPUAI = "zhipuai" + FIREWORKS = "fireworks" + OPEN_LLM = "open_llm" + GEMINI = "gemini" + METAGPT = "metagpt" + AZURE_OPENAI = "azure" + OLLAMA = "ollama" + + +class LLMConfig(YamlModel): + """Config for LLM + + OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions.py#L681 + Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields + """ + + api_key: str + api_type: LLMType = LLMType.OPENAI + base_url: str = "https://api.openai.com/v1" + api_version: Optional[str] = None + model: Optional[str] = None # also stands for DEPLOYMENT_NAME + + # For Spark(Xunfei), maybe remove later + app_id: Optional[str] = None + api_secret: Optional[str] = None + domain: Optional[str] = None + + # For Chat Completion + max_token: int = 4096 + temperature: float = 0.0 + top_p: float = 1.0 + top_k: int = 0 + repetition_penalty: float = 1.0 + stop: Optional[str] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + best_of: Optional[int] = None + n: Optional[int] = None + stream: bool = False + logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs + top_logprobs: Optional[int] = None + timeout: int = 60 + + # For Network + proxy: Optional[str] = None + + # Cost Control + calc_usage: bool = True + + @field_validator("api_key") + @classmethod + def check_llm_key(cls, v): + if v in ["", None, "YOUR_API_KEY"]: + raise ValueError("Please set your API key in config.yaml") + return v diff --git a/metagpt/configs/mermaid_config.py b/metagpt/configs/mermaid_config.py new file mode 100644 index 000000000..de4a3865c --- /dev/null +++ b/metagpt/configs/mermaid_config.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:07 +@Author : alexanderwu +@File : mermaid_config.py +""" +from typing import Literal + +from metagpt.utils.yaml_model import YamlModel + + +class MermaidConfig(YamlModel): + """Config for Mermaid""" + + engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs" + path: str = "" + puppeteer_config: str = "" # Only for nodejs engine diff --git a/metagpt/configs/redis_config.py b/metagpt/configs/redis_config.py new file mode 100644 index 000000000..c4cfb6764 --- /dev/null +++ b/metagpt/configs/redis_config.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:06 +@Author : alexanderwu +@File : redis_config.py +""" +from metagpt.utils.yaml_model import YamlModelWithoutDefault + + +class RedisConfig(YamlModelWithoutDefault): + host: str + port: int + username: str = "" + password: str + db: str + + def to_url(self): + return f"redis://{self.host}:{self.port}" + + def to_kwargs(self): + return { + "username": self.username, + "password": self.password, + "db": self.db, + } diff --git a/metagpt/configs/s3_config.py b/metagpt/configs/s3_config.py new file mode 100644 index 000000000..72b81fae4 --- /dev/null +++ b/metagpt/configs/s3_config.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:07 +@Author : alexanderwu +@File : s3_config.py +""" +from metagpt.utils.yaml_model import YamlModelWithoutDefault + + +class S3Config(YamlModelWithoutDefault): + access_key: str + secret_key: str + endpoint: str + bucket: str diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py new file mode 100644 index 000000000..a8ae918db --- /dev/null +++ b/metagpt/configs/search_config.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:06 +@Author : alexanderwu +@File : search_config.py +""" +from metagpt.tools import SearchEngineType +from metagpt.utils.yaml_model import YamlModel + + +class SearchConfig(YamlModel): + """Config for Search""" + + api_key: str + api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + cse_id: str = "" # for google diff --git a/metagpt/configs/workspace_config.py b/metagpt/configs/workspace_config.py new file mode 100644 index 000000000..df7aeaef9 --- /dev/null +++ b/metagpt/configs/workspace_config.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:09 +@Author : alexanderwu +@File : workspace_config.py +""" +from datetime import datetime +from pathlib import Path +from uuid import uuid4 + +from pydantic import field_validator, model_validator + +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class WorkspaceConfig(YamlModel): + path: Path = DEFAULT_WORKSPACE_ROOT + use_uid: bool = False + uid: str = "" + + @field_validator("path") + @classmethod + def check_workspace_path(cls, v): + if isinstance(v, str): + v = Path(v) + return v + + @model_validator(mode="after") + def check_uid_and_update_path(self): + if self.use_uid and not self.uid: + self.uid = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}" + self.path = self.path / self.uid + + # Create workspace path if not exists + self.path.mkdir(parents=True, exist_ok=True) + return self diff --git a/metagpt/context.py b/metagpt/context.py new file mode 100644 index 000000000..53b673b3e --- /dev/null +++ b/metagpt/context.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 16:32 +@Author : alexanderwu +@File : context.py +""" +import os +from pathlib import Path +from typing import Dict, Optional + +from pydantic import BaseModel + +from metagpt.config2 import Config +from metagpt.const import OPTIONS +from metagpt.provider.base_llm import BaseLLM +from metagpt.provider.llm_provider_registry import get_llm +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.git_repository import GitRepository + + +class Context(BaseModel): + kwargs: Dict = {} + config: Config = Config.default() + git_repo: Optional[GitRepository] = None + src_workspace: Optional[Path] = None + cost_manager: CostManager = CostManager() + + @property + def options(self): + """Return all key-values""" + return OPTIONS.get() + + def new_environ(self): + """Return a new os.environ object""" + env = os.environ.copy() + i = self.options + env.update({k: v for k, v in i.items() if isinstance(v, str)}) + return env + + def llm(self, name: Optional[str] = None) -> BaseLLM: + """Return a LLM instance""" + llm = get_llm(self.config.get_llm_config(name)) + if llm.cost_manager is None: + llm.cost_manager = self.cost_manager + return llm + + +# Global context +context = Context() + + +if __name__ == "__main__": + print(context.model_dump_json(indent=4)) + print(context.config.get_openai_llm()) diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index 1271f1c23..2359917d5 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -9,14 +9,13 @@ import asyncio from pathlib import Path from typing import Optional -from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores import FAISS from langchain_core.embeddings import Embeddings -from metagpt.config import CONFIG from metagpt.document import IndexableDocument from metagpt.document_store.base_store import LocalStore from metagpt.logs import logger +from metagpt.utils.embedding import get_embedding class FaissStore(LocalStore): @@ -25,9 +24,7 @@ class FaissStore(LocalStore): ): self.meta_col = meta_col self.content_col = content_col - self.embedding = embedding or OpenAIEmbeddings( - openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url - ) + self.embedding = embedding or get_embedding() super().__init__(raw_data, cache_dir) def _load(self) -> Optional["FaissStore"]: diff --git a/metagpt/environment.py b/metagpt/environment.py index ddb9ad9dd..b68aa40de 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -17,7 +17,7 @@ from typing import Iterable, Set from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator -from metagpt.config import CONFIG +from metagpt.context import Context from metagpt.logs import logger from metagpt.roles.role import Role from metagpt.schema import Message @@ -35,6 +35,7 @@ class Environment(BaseModel): roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True) members: dict[Role, Set] = Field(default_factory=dict, exclude=True) history: str = "" # For debug + context: Context = Field(default_factory=Context, exclude=True) @model_validator(mode="after") def init_roles(self): @@ -85,6 +86,7 @@ class Environment(BaseModel): """ self.roles[role.profile] = role role.set_env(self) + role.context = self.context def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 @@ -95,6 +97,7 @@ class Environment(BaseModel): for role in roles: # setup system message with roles role.set_env(self) + role.context = self.context def publish_message(self, message: Message, peekable: bool = True) -> bool: """ @@ -162,7 +165,6 @@ class Environment(BaseModel): """Set the labels for message to be consumed by the object""" self.members[obj] = tags - @staticmethod - def archive(auto_archive=True): - if auto_archive and CONFIG.git_repo: - CONFIG.git_repo.archive() + def archive(self, auto_archive=True): + if auto_archive and self.context.git_repo: + self.context.git_repo.archive() diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index c3c62fb67..1af66d6fb 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -8,33 +8,37 @@ """ import base64 -from metagpt.config import CONFIG +from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT +from metagpt.llm import LLM from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image from metagpt.utils.s3 import S3 -async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs): +async def text_to_image(text, size_type: str = "512x512", model_url="", config: Config = None): """Text to image :param text: The text used for image conversion. :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` :param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768']. :param model_url: MetaGPT model url + :param config: Config :return: The image data is returned in Base64 encoding. """ image_declaration = "data:image/png;base64," - if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url: + + if model_url: binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url) - elif CONFIG.OPENAI_API_KEY or openai_api_key: - binary_data = await oas3_openai_text_to_image(text, size_type) + elif oai_llm := config.get_openai_llm(): + binary_data = await oas3_openai_text_to_image(text, size_type, LLM(oai_llm)) else: raise ValueError("Missing necessary parameters.") base64_data = base64.b64encode(binary_data).decode("utf-8") - s3 = S3() - url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else "" + assert config.s3, "S3 config is required." + s3 = S3(config.s3) + url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if url: return f"![{text}]({url})" return image_declaration + base64_data if base64_data else "" diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index ecd00c724..9ee3d64ee 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -48,7 +48,7 @@ async def text_to_speech( audio_declaration = "data:audio/wav;base64," base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region) s3 = S3() - url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else "" + url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if url: return f"[{text}]({url})" return audio_declaration + base64_data if base64_data else base64_data @@ -60,7 +60,7 @@ async def text_to_speech( text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret ) s3 = S3() - url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else "" + url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if url: return f"[{text}]({url})" return audio_declaration + base64_data if base64_data else base64_data diff --git a/metagpt/llm.py b/metagpt/llm.py index 76dd5a0f8..9a473e306 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -8,17 +8,10 @@ from typing import Optional -from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.context import context from metagpt.provider.base_llm import BaseLLM -from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.llm_provider_registry import LLM_REGISTRY - -_ = HumanProvider() # Avoid pre-commit error -def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM: - """get the default llm provider""" - if provider is None: - provider = CONFIG.get_default_llm_provider_enum() - - return LLM_REGISTRY.get_provider(provider) +def LLM(name: Optional[str] = None) -> BaseLLM: + """get the default llm provider if name is None""" + return context.llm(name) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 28157a4e2..33f43b148 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -14,6 +14,7 @@ from metagpt.provider.openai_api import OpenAILLM from metagpt.provider.zhipuai_api import ZhiPuAILLM from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.metagpt_api import MetaGPTLLM +from metagpt.provider.human_provider import HumanProvider __all__ = [ "FireworksLLM", @@ -24,4 +25,5 @@ __all__ = [ "AzureOpenAILLM", "MetaGPTLLM", "OllamaLLM", + "HumanProvider", ] diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index b9d7d9e38..2a65b81c1 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -9,12 +9,15 @@ import anthropic from anthropic import Anthropic, AsyncAnthropic -from metagpt.config import CONFIG +from metagpt.configs.llm_config import LLMConfig class Claude2: + def __init__(self, config: LLMConfig = None): + self.config = config + def ask(self, prompt: str) -> str: - client = Anthropic(api_key=CONFIG.anthropic_api_key) + client = Anthropic(api_key=self.config.api_key) res = client.completions.create( model="claude-2", @@ -24,7 +27,7 @@ class Claude2: return res.completion async def aask(self, prompt: str) -> str: - aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key) + aclient = AsyncAnthropic(api_key=self.config.api_key) res = await aclient.completions.create( model="claude-2", diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py index d15d1c82e..987eafc4c 100644 --- a/metagpt/provider/azure_openai_api.py +++ b/metagpt/provider/azure_openai_api.py @@ -13,12 +13,12 @@ from openai import AsyncAzureOpenAI from openai._base_client import AsyncHttpxClientWrapper -from metagpt.config import LLMProviderEnum +from metagpt.configs.llm_config import LLMType from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM -@register_provider(LLMProviderEnum.AZURE_OPENAI) +@register_provider(LLMType.AZURE_OPENAI) class AzureOpenAILLM(OpenAILLM): """ Check https://platform.openai.com/examples for examples diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 52dd96b1a..f13899c38 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -8,15 +8,30 @@ """ import json from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Union + +from openai import AsyncOpenAI + +from metagpt.configs.llm_config import LLMConfig +from metagpt.schema import Message +from metagpt.utils.cost_manager import CostManager class BaseLLM(ABC): """LLM API abstract class, requiring all inheritors to provide a series of standard capabilities""" + config: LLMConfig use_system_prompt: bool = True system_prompt = "You are a helpful assistant." + # OpenAI / Azure / Others + aclient: Optional[Union[AsyncOpenAI]] = None + cost_manager: Optional[CostManager] = None + + @abstractmethod + def __init__(self, config: LLMConfig = None): + pass + def _user_msg(self, msg: str) -> dict[str, str]: return {"role": "user", "content": msg} @@ -63,10 +78,9 @@ class BaseLLM(ABC): context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - async def aask_code(self, msgs: list[str], timeout=3) -> str: + async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict: """FIXME: No code segment filtering has been done here, and all results are actually displayed""" - rsp_text = await self.aask_batch(msgs, timeout=timeout) - return rsp_text + raise NotImplementedError @abstractmethod async def acompletion(self, messages: list[dict], timeout=3): diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index f0af68818..09581a2f3 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -15,7 +15,7 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM, log_and_reraise @@ -64,27 +64,18 @@ class FireworksCostManager(CostManager): token_costs = self.model_grade_token_costs(model) cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000 self.total_cost += cost - max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget logger.info( - f"Total running cost: ${self.total_cost:.4f} | Max budget: ${max_budget:.3f} | " + f"Total running cost: ${self.total_cost:.4f}" f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) - CONFIG.total_cost = self.total_cost -@register_provider(LLMProviderEnum.FIREWORKS) +@register_provider(LLMType.FIREWORKS) class FireworksLLM(OpenAILLM): - def __init__(self): - self.config: Config = CONFIG - self.__init_fireworks() + def __init__(self, config: LLMConfig = None): + super().__init__(config=config) self.auto_max_tokens = False - self._cost_manager = FireworksCostManager() - - def __init_fireworks(self): - self.is_azure = False - self.rpm = int(self.config.get("RPM", 10)) - self._init_client() - self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it + self.cost_manager = FireworksCostManager() def _make_client_kwargs(self) -> dict: kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base) @@ -94,14 +85,14 @@ class FireworksLLM(OpenAILLM): if self.config.calc_usage and usage: try: # use FireworksCostManager not CONFIG.cost_manager - self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) except Exception as e: logger.error(f"updating costs failed!, exp: {e}") def get_costs(self) -> Costs: - return self._cost_manager.get_costs() + return self.cost_manager.get_costs() - async def _achat_completion_stream(self, messages: list[dict]) -> str: + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages), stream=True ) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 795687773..0f2251792 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -19,7 +19,7 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider @@ -41,21 +41,21 @@ class GeminiGenerativeModel(GenerativeModel): return await self._async_client.count_tokens(model=self.model_name, contents=contents) -@register_provider(LLMProviderEnum.GEMINI) +@register_provider(LLMType.GEMINI) class GeminiLLM(BaseLLM): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` """ - def __init__(self): + def __init__(self, config: LLMConfig = None): self.use_system_prompt = False # google gemini has no system prompt when use api - self.__init_gemini(CONFIG) + self.__init_gemini(config) self.model = "gemini-pro" # so far only one model self.llm = GeminiGenerativeModel(model_name=self.model) - def __init_gemini(self, config: CONFIG): - genai.configure(api_key=config.gemini_api_key) + def __init_gemini(self, config: LLMConfig): + genai.configure(api_key=config.api_key) def _user_msg(self, msg: str) -> dict[str, str]: # Not to change BaseLLM default functions but update with Gemini's conversation format. @@ -71,11 +71,11 @@ class GeminiLLM(BaseLLM): def _update_costs(self, usage: dict): """update each request's token cost""" - if CONFIG.calc_usage: + if self.config.calc_usage: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) completion_tokens = int(usage.get("completion_tokens", 0)) - CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") @@ -108,7 +108,7 @@ class GeminiLLM(BaseLLM): self._update_costs(usage) return resp - async def acompletion(self, messages: list[dict]) -> dict: + async def acompletion(self, messages: list[dict], timeout=3) -> dict: return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index 59d236a3a..25b897d74 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -5,6 +5,7 @@ Author: garylin2099 """ from typing import Optional +from metagpt.configs.llm_config import LLMConfig from metagpt.logs import logger from metagpt.provider.base_llm import BaseLLM @@ -14,6 +15,9 @@ class HumanProvider(BaseLLM): This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction """ + def __init__(self, config: LLMConfig = None): + pass + def ask(self, msg: str, timeout=3) -> str: logger.info("It's your turn, please type in your response. You may also refer to the context below") rsp = input(msg) diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py index 2b3ef93a3..2f68f27c8 100644 --- a/metagpt/provider/llm_provider_registry.py +++ b/metagpt/provider/llm_provider_registry.py @@ -5,7 +5,8 @@ @Author : alexanderwu @File : llm_provider_registry.py """ -from metagpt.config import LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.provider.base_llm import BaseLLM class LLMProviderRegistry: @@ -15,13 +16,9 @@ class LLMProviderRegistry: def register(self, key, provider_cls): self.providers[key] = provider_cls - def get_provider(self, enum: LLMProviderEnum): + def get_provider(self, enum: LLMType): """get provider instance according to the enum""" - return self.providers[enum]() - - -# Registry instance -LLM_REGISTRY = LLMProviderRegistry() + return self.providers[enum] def register_provider(key): @@ -32,3 +29,12 @@ def register_provider(key): return cls return decorator + + +def get_llm(config: LLMConfig) -> BaseLLM: + """get the default llm provider""" + return LLM_REGISTRY.get_provider(config.api_type)(config) + + +# Registry instance +LLM_REGISTRY = LLMProviderRegistry() diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py index 69aa7f305..4956746dc 100644 --- a/metagpt/provider/metagpt_api.py +++ b/metagpt/provider/metagpt_api.py @@ -5,12 +5,11 @@ @File : metagpt_api.py @Desc : MetaGPT LLM provider. """ -from metagpt.config import LLMProviderEnum +from metagpt.configs.llm_config import LLMType from metagpt.provider import OpenAILLM from metagpt.provider.llm_provider_registry import register_provider -@register_provider(LLMProviderEnum.METAGPT) +@register_provider(LLMType.METAGPT) class MetaGPTLLM(OpenAILLM): - def __init__(self): - super().__init__() + pass diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 8ee04de7d..35e39c9cc 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -13,48 +13,33 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise -from metagpt.utils.cost_manager import CostManager +from metagpt.utils.cost_manager import TokenCostManager -class OllamaCostManager(CostManager): - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget - logger.info( - f"Max budget: ${max_budget:.3f} | " - f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" - ) - CONFIG.total_cost = self.total_cost - - -@register_provider(LLMProviderEnum.OLLAMA) +@register_provider(LLMType.OLLAMA) class OllamaLLM(BaseLLM): """ Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion` """ - def __init__(self): - self.__init_ollama(CONFIG) - self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base) + def __init__(self, config: LLMConfig = None): + self.__init_ollama(config) + self.client = GeneralAPIRequestor(base_url=config.api_base) self.suffix_url = "/chat" self.http_method = "post" self.use_system_prompt = False - self._cost_manager = OllamaCostManager() + self._cost_manager = TokenCostManager() - def __init_ollama(self, config: CONFIG): - assert config.ollama_api_base - self.model = config.ollama_api_model + def __init_ollama(self, config: LLMConfig): + assert config.api_base + self.model = config.model def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} @@ -62,7 +47,7 @@ class OllamaLLM(BaseLLM): def _update_costs(self, usage: dict): """update each request's token cost""" - if CONFIG.calc_usage: + if self.config.calc_usage: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) completion_tokens = int(usage.get("completion_tokens", 0)) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index b0c484f5a..a29b263a4 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -4,56 +4,27 @@ from openai.types import CompletionUsage -from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import Costs, TokenCostManager from metagpt.utils.token_counter import count_message_tokens, count_string_tokens -class OpenLLMCostManager(CostManager): - """open llm model is self-host, it's free and without cost""" - - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - - Args: - prompt_tokens (int): The number of tokens used in the prompt. - completion_tokens (int): The number of tokens used in the completion. - model (str): The model used for the API call. - """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget - logger.info( - f"Max budget: ${max_budget:.3f} | reference " - f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" - ) - - -@register_provider(LLMProviderEnum.OPEN_LLM) +@register_provider(LLMType.OPEN_LLM) class OpenLLM(OpenAILLM): - def __init__(self): - self.config: Config = CONFIG - self.__init_openllm() - self.auto_max_tokens = False - self._cost_manager = OpenLLMCostManager() - - def __init_openllm(self): - self.is_azure = False - self.rpm = int(self.config.get("RPM", 10)) - self._init_client() - self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it + def __init__(self, config: LLMConfig): + super().__init__(config) + self._cost_manager = TokenCostManager() def _make_client_kwargs(self) -> dict: - kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base) + kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url) return kwargs def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage: usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - if not CONFIG.calc_usage: + if not self.config.calc_usage: return usage try: diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 20dde9ea5..c1337a9f8 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -10,7 +10,7 @@ """ import json -from typing import AsyncIterator, Union +from typing import AsyncIterator, Optional, Union from openai import APIConnectionError, AsyncOpenAI, AsyncStream from openai._base_client import AsyncHttpxClientWrapper @@ -24,13 +24,13 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message -from metagpt.utils.cost_manager import Costs +from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( count_message_tokens, @@ -50,18 +50,19 @@ See FAQ 5.8 raise retry_state.outcome.exception() -@register_provider(LLMProviderEnum.OPENAI) +@register_provider(LLMType.OPENAI) class OpenAILLM(BaseLLM): """Check https://platform.openai.com/examples for examples""" - def __init__(self): - self.config: Config = CONFIG - self._init_openai() + def __init__(self, config: LLMConfig = None): + self.config = config + self._init_model() self._init_client() self.auto_max_tokens = False + self.cost_manager: Optional[CostManager] = None - def _init_openai(self): - self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs + def _init_model(self): + self.model = self.config.model # Used in _calc_usage & _cons_kwargs def _init_client(self): """https://github.com/openai/openai-python#async-usage""" @@ -69,7 +70,7 @@ class OpenAILLM(BaseLLM): self.aclient = AsyncOpenAI(**kwargs) def _make_client_kwargs(self) -> dict: - kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url} + kwargs = {"api_key": self.config.api_key, "base_url": self.config.base_url} # to use proxy, openai v1 needs http_client if proxy_params := self._get_proxy_params(): @@ -79,10 +80,10 @@ class OpenAILLM(BaseLLM): def _get_proxy_params(self) -> dict: params = {} - if self.config.openai_proxy: - params = {"proxies": self.config.openai_proxy} - if self.config.openai_base_url: - params["base_url"] = self.config.openai_base_url + if self.config.proxy: + params = {"proxies": self.config.proxy} + if self.config.base_url: + params["base_url"] = self.config.base_url return params @@ -103,7 +104,7 @@ class OpenAILLM(BaseLLM): "stop": None, "temperature": 0.3, "model": self.model, - "timeout": max(CONFIG.timeout, timeout), + "timeout": max(self.config.timeout, timeout), } if extra_kwargs: kwargs.update(extra_kwargs) @@ -205,7 +206,7 @@ class OpenAILLM(BaseLLM): def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage: usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - if not CONFIG.calc_usage: + if not self.config.calc_usage: return usage try: @@ -218,16 +219,16 @@ class OpenAILLM(BaseLLM): @handle_exception def _update_costs(self, usage: CompletionUsage): - if CONFIG.calc_usage and usage: - CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + if self.config.calc_usage and usage: + self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) def get_costs(self) -> Costs: - return CONFIG.cost_manager.get_costs() + return self.cost_manager.get_costs() def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: - return CONFIG.max_tokens_rsp - return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) + return self.config.max_token + return get_max_completion_tokens(messages, self.model, self.config.max_tokens) @handle_exception async def amoderation(self, content: Union[str, list[str]]): diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index ce889529a..bc842f202 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -16,15 +16,16 @@ from wsgiref.handlers import format_date_time import websocket # 使用websocket_client -from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider -@register_provider(LLMProviderEnum.SPARK) +@register_provider(LLMType.SPARK) class SparkLLM(BaseLLM): - def __init__(self): + def __init__(self, config: LLMConfig = None): + self.config = config logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") def get_choice_text(self, rsp: dict) -> str: @@ -33,12 +34,12 @@ class SparkLLM(BaseLLM): async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: # 不支持 logger.error("该功能禁用。") - w = GetMessageFromWeb(messages) + w = GetMessageFromWeb(messages, self.config) return w.run() async def acompletion(self, messages: list[dict], timeout=3): # 不支持异步 - w = GetMessageFromWeb(messages) + w = GetMessageFromWeb(messages, self.config) return w.run() @@ -89,14 +90,14 @@ class GetMessageFromWeb: # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 return url - def __init__(self, text): + def __init__(self, text, config): self.text = text self.ret = "" - self.spark_appid = CONFIG.spark_appid - self.spark_api_secret = CONFIG.spark_api_secret - self.spark_api_key = CONFIG.spark_api_key - self.domain = CONFIG.domain - self.spark_url = CONFIG.spark_url + self.spark_appid = config.app_id + self.spark_api_secret = config.api_secret + self.spark_api_key = config.api_key + self.domain = config.domain + self.spark_url = config.base_url def on_message(self, ws, message): data = json.loads(message) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 865b7fce1..61e9c1aa6 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -16,7 +16,7 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider @@ -31,27 +31,27 @@ class ZhiPuEvent(Enum): FINISH = "finish" -@register_provider(LLMProviderEnum.ZHIPUAI) +@register_provider(LLMType.ZHIPUAI) class ZhiPuAILLM(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` From now, there is only one model named `chatglm_turbo` """ - def __init__(self): - self.__init_zhipuai(CONFIG) + def __init__(self, config: LLMConfig = None): + self.__init_zhipuai(config) self.llm = ZhiPuModelAPI self.model = "chatglm_turbo" # so far only one model, just use it self.use_system_prompt: bool = False # zhipuai has no system prompt when use api - def __init_zhipuai(self, config: CONFIG): - assert config.zhipuai_api_key - zhipuai.api_key = config.zhipuai_api_key + def __init_zhipuai(self, config: LLMConfig): + assert config.api_key + zhipuai.api_key = config.api_key # due to use openai sdk, set the api_key but it will't be used. # openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. - if config.openai_proxy: + if config.proxy: # FIXME: openai v1.x sdk has no proxy support - openai.proxy = config.openai_proxy + openai.proxy = config.proxy def _const_kwargs(self, messages: list[dict]) -> dict: kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} @@ -59,11 +59,11 @@ class ZhiPuAILLM(BaseLLM): def _update_costs(self, usage: dict): """update each request's token cost""" - if CONFIG.calc_usage: + if self.config.calc_usage: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) completion_tokens = int(usage.get("completion_tokens", 0)) - CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index e05e69cbb..e20ea42a7 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -27,7 +27,6 @@ from typing import Set from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.summarize_code import SummarizeCode -from metagpt.config import CONFIG from metagpt.const import ( CODE_SUMMARIES_FILE_REPO, CODE_SUMMARIES_PDF_FILE_REPO, @@ -80,6 +79,7 @@ class Engineer(Role): code_todos: list = [] summarize_todos: list = [] next_todo_action: str = "" + n_summarize: int = 0 def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -97,7 +97,7 @@ class Engineer(Role): async def _act_sp_with_cr(self, review=False) -> Set[str]: changed_files = set() - src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) + src_file_repo = self.git_repo.new_file_repository(self.src_workspace) for todo in self.code_todos: """ # Select essential information from the historical data to reduce the length of the prompt (summarized from human experience): @@ -153,10 +153,10 @@ class Engineer(Role): ) async def _act_summarize(self): - code_summaries_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO) - code_summaries_pdf_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO) + code_summaries_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO) + code_summaries_pdf_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO) tasks = [] - src_relative_path = CONFIG.src_workspace.relative_to(CONFIG.git_repo.workdir) + src_relative_path = self.src_workspace.relative_to(self.git_repo.workdir) for todo in self.summarize_todos: summary = await todo.run() summary_filename = Path(todo.context.design_filename).with_suffix(".md").name @@ -179,8 +179,8 @@ class Engineer(Role): else: await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name) - logger.info(f"--max-auto-summarize-code={CONFIG.max_auto_summarize_code}") - if not tasks or CONFIG.max_auto_summarize_code == 0: + logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}") + if not tasks or self.config.max_auto_summarize_code == 0: return Message( content="", role=self.profile, @@ -190,7 +190,7 @@ class Engineer(Role): ) # The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. # This parameter is used for debugging the workflow. - CONFIG.max_auto_summarize_code -= 1 if CONFIG.max_auto_summarize_code > 0 else 0 + self.n_summarize += 1 if self.config.max_auto_summarize_code > self.n_summarize else 0 return Message( content=json.dumps(tasks), role=self.profile, cause_by=SummarizeCode, send_to=self, sent_from=self ) @@ -203,8 +203,8 @@ class Engineer(Role): return False, rsp async def _think(self) -> Action | None: - if not CONFIG.src_workspace: - CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name + if not self.src_workspace: + self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug]) summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview]) if not self.rc.news: @@ -253,11 +253,11 @@ class Engineer(Role): async def _new_code_actions(self, bug_fix=False): # Prepare file repos - src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) + src_file_repo = self.git_repo.new_file_repository(self.src_workspace) changed_src_files = src_file_repo.all_files if bug_fix else src_file_repo.changed_files - task_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO) + task_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO) changed_task_files = task_file_repo.changed_files - design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) + design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_files = Documents() # Recode caused by upstream changes. @@ -283,7 +283,7 @@ class Engineer(Role): changed_files.docs[task_filename] = coding_doc self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()] # Code directly modified by the user. - dependency = await CONFIG.git_repo.get_dependency() + dependency = await self.git_repo.get_dependency() for filename in changed_src_files: if filename in changed_files.docs: continue @@ -301,7 +301,7 @@ class Engineer(Role): self.rc.todo = self.code_todos[0] async def _new_summarize_actions(self): - src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) + src_file_repo = self.git_repo.new_file_repository(self.src_workspace) src_files = src_file_repo.all_files # Generate a SummarizeCode action for each pair of (system_design_doc, task_doc). summarizations = defaultdict(list) diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 1d82ac3f2..427c8acb5 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -9,7 +9,6 @@ from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.config import CONFIG from metagpt.roles.role import Role from metagpt.utils.common import any_to_name @@ -40,11 +39,11 @@ class ProductManager(Role): async def _think(self) -> bool: """Decide what to do""" - if CONFIG.git_repo and not CONFIG.git_reinit: + if self.git_repo and not self.config.git_reinit: self._set_state(1) else: self._set_state(0) - CONFIG.git_reinit = False + self.context.config.git_reinit = False self.todo_action = any_to_name(WritePRD) return bool(self.rc.todo) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index b1d06d122..1a6ca2d9c 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -15,10 +15,9 @@ of SummarizeCode. """ - from metagpt.actions import DebugError, RunCode, WriteTest from metagpt.actions.summarize_code import SummarizeCode -from metagpt.config import CONFIG +from metagpt.config2 import Config from metagpt.const import ( MESSAGE_ROUTE_TO_NONE, TEST_CODES_FILE_REPO, @@ -50,13 +49,17 @@ class QaEngineer(Role): self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 + @property + def config(self) -> Config: + return self.context.config + async def _write_test(self, message: Message) -> None: - src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) + src_file_repo = self.context.git_repo.new_file_repository(self.context.src_workspace) changed_files = set(src_file_repo.changed_files.keys()) # Unit tests only. - if CONFIG.reqa_file and CONFIG.reqa_file not in changed_files: - changed_files.add(CONFIG.reqa_file) - tests_file_repo = CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO) + if self.config.reqa_file and self.config.reqa_file not in changed_files: + changed_files.add(self.config.reqa_file) + tests_file_repo = self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO) for filename in changed_files: # write tests if not filename or "test" in filename: @@ -69,7 +72,7 @@ class QaEngineer(Role): ) logger.info(f"Writing {test_doc.filename}..") context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc) - context = await WriteTest(context=context, llm=self.llm).run() + context = await WriteTest(context=context, _context=self.context, llm=self.llm).run() await tests_file_repo.save( filename=context.test_doc.filename, content=context.test_doc.content, @@ -81,8 +84,8 @@ class QaEngineer(Role): command=["python", context.test_doc.root_relative_path], code_filename=context.code_doc.filename, test_filename=context.test_doc.filename, - working_directory=str(CONFIG.git_repo.workdir), - additional_python_paths=[str(CONFIG.src_workspace)], + working_directory=str(self.context.git_repo.workdir), + additional_python_paths=[str(self.context.src_workspace)], ) self.publish_message( Message( @@ -98,17 +101,21 @@ class QaEngineer(Role): async def _run_code(self, msg): run_code_context = RunCodeContext.loads(msg.content) - src_doc = await CONFIG.git_repo.new_file_repository(CONFIG.src_workspace).get(run_code_context.code_filename) + src_doc = await self.context.git_repo.new_file_repository(self.context.src_workspace).get( + run_code_context.code_filename + ) if not src_doc: return - test_doc = await CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(run_code_context.test_filename) + test_doc = await self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get( + run_code_context.test_filename + ) if not test_doc: return run_code_context.code = src_doc.content run_code_context.test_code = test_doc.content result = await RunCode(context=run_code_context, llm=self.llm).run() run_code_context.output_filename = run_code_context.test_filename + ".json" - await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save( + await self.context.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save( filename=run_code_context.output_filename, content=result.model_dump_json(), dependencies={src_doc.root_relative_path, test_doc.root_relative_path}, diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 356b9e33f..63316b5de 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -32,9 +32,11 @@ from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.const import SERDESER_PATH -from metagpt.llm import LLM, HumanProvider +from metagpt.context import Context +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory +from metagpt.provider import HumanProvider from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, MessageQueue, SerializationMixin from metagpt.utils.common import ( @@ -148,9 +150,46 @@ class Role(SerializationMixin, is_polymorphic_base=True): # builtin variables recovered: bool = False # to tag if a recovered role latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted + context: Optional[Context] = Field(default=None, exclude=True) __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` + @property + def config(self): + return self.context.config + + @property + def git_repo(self): + return self.context.git_repo + + @git_repo.setter + def git_repo(self, value): + self.context.git_repo = value + + @property + def src_workspace(self): + return self.context.src_workspace + + @src_workspace.setter + def src_workspace(self, value): + self.context.src_workspace = value + + @property + def prompt_schema(self): + return self.context.config.prompt_schema + + @property + def project_name(self): + return self.context.config.project_name + + @project_name.setter + def project_name(self, value): + self.context.config.project_name = value + + @property + def project_path(self): + return self.context.config.project_path + @model_validator(mode="after") def check_subscription(self): if not self.subscription: diff --git a/metagpt/roles/teacher.py b/metagpt/roles/teacher.py index 5449fe828..637fd242a 100644 --- a/metagpt/roles/teacher.py +++ b/metagpt/roles/teacher.py @@ -15,7 +15,6 @@ import aiofiles from metagpt.actions import UserRequirement from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart -from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message @@ -81,7 +80,7 @@ class Teacher(Role): async def save(self, content): """Save teaching plan""" filename = Teacher.new_file_name(self.course_title) - pathname = CONFIG.workspace_path / "teaching_plan" + pathname = self.config.workspace.path / "teaching_plan" pathname.mkdir(exist_ok=True) pathname = pathname / filename try: diff --git a/metagpt/schema.py b/metagpt/schema.py index e36bef395..ec04d321c 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -35,7 +35,6 @@ from pydantic import ( ) from pydantic_core import core_schema -from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM, @@ -151,12 +150,6 @@ class Document(BaseModel): """ return os.path.join(self.root_path, self.filename) - @property - def full_path(self): - if not CONFIG.git_repo: - return None - return str(CONFIG.git_repo.workdir / self.root_path / self.filename) - def __str__(self): return self.content diff --git a/metagpt/startup.py b/metagpt/startup.py index 767a19a9d..e7ae2b09e 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -5,7 +5,7 @@ from pathlib import Path import typer -from metagpt.config import CONFIG +from metagpt.config2 import config app = typer.Typer(add_completion=False) @@ -44,7 +44,7 @@ def startup( ) from metagpt.team import Team - CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) + config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) if not recover_path: company = Team() diff --git a/metagpt/team.py b/metagpt/team.py index b98fc2efb..87fee8dc7 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -15,7 +15,6 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field from metagpt.actions import UserRequirement -from metagpt.config import CONFIG from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH from metagpt.environment import Environment from metagpt.logs import logger @@ -79,18 +78,20 @@ class Team(BaseModel): """Hire roles to cooperate""" self.env.add_roles(roles) + @property + def cost_manager(self): + """Get cost manager""" + return self.env.context.cost_manager + def invest(self, investment: float): """Invest company. raise NoMoneyException when exceed max_budget.""" self.investment = investment - CONFIG.max_budget = investment + self.cost_manager.max_budget = investment logger.info(f"Investment: ${investment}.") - @staticmethod - def _check_balance(): - if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget: - raise NoMoneyException( - CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}" - ) + def _check_balance(self): + if self.cost_manager.total_cost > self.cost_manager.max_budget: + raise NoMoneyException(self.cost_manager.total_cost, f"Insufficient funds: {self.cost_manager.max_budget}") def run_project(self, idea, send_to: str = ""): """Run a project from publishing user requirement.""" diff --git a/metagpt/tools/metagpt_text_to_image.py b/metagpt/tools/metagpt_text_to_image.py index 9a84e69eb..cf7bf97e7 100644 --- a/metagpt/tools/metagpt_text_to_image.py +++ b/metagpt/tools/metagpt_text_to_image.py @@ -13,7 +13,6 @@ import aiohttp import requests from pydantic import BaseModel -from metagpt.config import CONFIG from metagpt.logs import logger @@ -22,7 +21,7 @@ class MetaGPTText2Image: """ :param model_url: Model reset api url """ - self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL + self.model_url = model_url async def text_2_image(self, text, size_type="512x512"): """Text to image @@ -93,6 +92,4 @@ async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url """ if not text: return "" - if not model_url: - model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type) diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index aa00abdcc..fc31b95f7 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -10,16 +10,16 @@ import aiohttp import requests -from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_llm import BaseLLM class OpenAIText2Image: - def __init__(self): + def __init__(self, llm: BaseLLM): """ :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ - self._llm = LLM() + self.llm = llm async def text_2_image(self, text, size_type="1024x1024"): """Text to image @@ -29,7 +29,7 @@ class OpenAIText2Image: :return: The image data is returned in Base64 encoding. """ try: - result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type) + result = await self.llm.aclient.images.generate(prompt=text, n=1, size=size_type) except Exception as e: logger.error(f"An error occurred:{e}") return "" @@ -57,13 +57,14 @@ class OpenAIText2Image: # Export -async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"): +async def oas3_openai_text_to_image(text, size_type: str = "1024x1024", llm: BaseLLM = None): """Text to image :param text: The text used for image conversion. :param size_type: One of ['256x256', '512x512', '1024x1024'] + :param llm: LLM instance :return: The image data is returned in Base64 encoding. """ if not text: return "" - return await OpenAIText2Image().text_2_image(text, size_type=size_type) + return await OpenAIText2Image(llm).text_2_image(text, size_type=size_type) diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index c4d9d2df4..c56b335ca 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -77,7 +77,7 @@ class SDEngine: return self.payload def _save(self, imgs, save_name=""): - save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO + save_dir = CONFIG.path / SD_OUTPUT_FILE_REPO if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name) diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py index ce53f2285..7bf5154b6 100644 --- a/metagpt/utils/cost_manager.py +++ b/metagpt/utils/cost_manager.py @@ -80,3 +80,20 @@ class CostManager(BaseModel): def get_costs(self) -> Costs: """Get all costs""" return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) + + +class TokenCostManager(CostManager): + """open llm model is self-host, it's free and without cost""" + + def update_cost(self, prompt_tokens, completion_tokens, model): + """ + Update the total cost, prompt tokens, and completion tokens. + + Args: + prompt_tokens (int): The number of tokens used in the prompt. + completion_tokens (int): The number of tokens used in the completion. + model (str): The model used for the API call. + """ + self.total_prompt_tokens += prompt_tokens + self.total_completion_tokens += completion_tokens + logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}") diff --git a/metagpt/utils/embedding.py b/metagpt/utils/embedding.py new file mode 100644 index 000000000..21d62948c --- /dev/null +++ b/metagpt/utils/embedding.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 20:58 +@Author : alexanderwu +@File : embedding.py +""" +from langchain_community.embeddings import OpenAIEmbeddings + +from metagpt.config2 import config + + +def get_embedding(): + llm = config.get_openai_llm() + embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url) + return embedding diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index 10f33285c..7a640563a 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -12,26 +12,25 @@ from datetime import timedelta import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/ -from metagpt.config import CONFIG +from metagpt.configs.redis_config import RedisConfig from metagpt.logs import logger class Redis: - def __init__(self): + def __init__(self, config: RedisConfig = None): + self.config = config self._client = None async def _connect(self, force=False): if self._client and not force: return True - if not self.is_configured: - return False try: self._client = await aioredis.from_url( - f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}", - username=CONFIG.REDIS_USER, - password=CONFIG.REDIS_PASSWORD, - db=CONFIG.REDIS_DB, + self.config.to_url(), + username=self.config.username, + password=self.config.password, + db=self.config.db, ) return True except Exception as e: @@ -62,18 +61,3 @@ class Redis: return await self._client.close() self._client = None - - @property - def is_valid(self) -> bool: - return self._client is not None - - @property - def is_configured(self) -> bool: - return bool( - CONFIG.REDIS_HOST - and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" - and CONFIG.REDIS_PORT - and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" - and CONFIG.REDIS_DB is not None - and CONFIG.REDIS_PASSWORD is not None - ) diff --git a/metagpt/utils/s3.py b/metagpt/utils/s3.py index 2a2c1a31c..c0afbb2f5 100644 --- a/metagpt/utils/s3.py +++ b/metagpt/utils/s3.py @@ -8,7 +8,7 @@ from typing import Optional import aioboto3 import aiofiles -from metagpt.config import CONFIG +from metagpt.config2 import S3Config from metagpt.const import BASE64_FORMAT from metagpt.logs import logger @@ -16,13 +16,14 @@ from metagpt.logs import logger class S3: """A class for interacting with Amazon S3 storage.""" - def __init__(self): + def __init__(self, config: S3Config): self.session = aioboto3.Session() + self.config = config self.auth_config = { "service_name": "s3", - "aws_access_key_id": CONFIG.S3_ACCESS_KEY, - "aws_secret_access_key": CONFIG.S3_SECRET_KEY, - "endpoint_url": CONFIG.S3_ENDPOINT_URL, + "aws_access_key_id": config.access_key, + "aws_secret_access_key": config.secret_key, + "endpoint_url": config.endpoint, } async def upload_file( @@ -139,8 +140,8 @@ class S3: data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8") await file.write(data) - bucket = CONFIG.S3_BUCKET - object_pathname = CONFIG.S3_BUCKET or "system" + bucket = self.config.bucket + object_pathname = self.config.bucket or "system" object_pathname += f"/{object_name}" object_pathname = os.path.normpath(object_pathname) await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname) @@ -151,20 +152,3 @@ class S3: logger.exception(f"{e}, stack:{traceback.format_exc()}") pathname.unlink(missing_ok=True) return None - - @property - def is_valid(self): - return self.is_configured - - @property - def is_configured(self) -> bool: - return bool( - CONFIG.S3_ACCESS_KEY - and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY" - and CONFIG.S3_SECRET_KEY - and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY" - and CONFIG.S3_ENDPOINT_URL - and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL" - and CONFIG.S3_BUCKET - and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET" - ) diff --git a/metagpt/utils/yaml_model.py b/metagpt/utils/yaml_model.py new file mode 100644 index 000000000..85bdbf9bb --- /dev/null +++ b/metagpt/utils/yaml_model.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 10:18 +@Author : alexanderwu +@File : YamlModel.py +""" +from pathlib import Path +from typing import Dict, Optional + +import yaml +from pydantic import BaseModel, model_validator + + +class YamlModel(BaseModel): + extra_fields: Optional[Dict[str, str]] = None + + @classmethod + def read_yaml(cls, file_path: Path) -> Dict: + with open(file_path, "r") as file: + return yaml.safe_load(file) + + @classmethod + def model_validate_yaml(cls, file_path: Path) -> "YamlModel": + return cls(**cls.read_yaml(file_path)) + + def model_dump_yaml(self, file_path: Path) -> None: + with open(file_path, "w") as file: + yaml.dump(self.model_dump(), file) + + +class YamlModelWithoutDefault(YamlModel): + @model_validator(mode="before") + @classmethod + def check_not_default_config(cls, values): + if any(["YOUR" in v for v in values]): + raise ValueError("Please set your S3 config in config.yaml") + return values diff --git a/tests/metagpt/learn/test_text_to_image.py b/tests/metagpt/learn/test_text_to_image.py index 760b9d09c..85fa679b3 100644 --- a/tests/metagpt/learn/test_text_to_image.py +++ b/tests/metagpt/learn/test_text_to_image.py @@ -10,29 +10,26 @@ import pytest -from metagpt.config import CONFIG +from metagpt.config2 import Config from metagpt.learn.text_to_image import text_to_image @pytest.mark.asyncio -async def test_metagpt_llm(): - # Prerequisites - assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL - assert CONFIG.OPENAI_API_KEY +async def test_metagpt_text_to_image(): + config = Config() + assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL - data = await text_to_image("Panda emoji", size_type="512x512") + data = await text_to_image("Panda emoji", size_type="512x512", model_url=config.METAGPT_TEXT_TO_IMAGE_MODEL_URL) assert "base64" in data or "http" in data - # Mock session env - old_options = CONFIG.options.copy() - new_options = old_options.copy() - new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None - CONFIG.set_context(new_options) - try: - data = await text_to_image("Panda emoji", size_type="512x512") - assert "base64" in data or "http" in data - finally: - CONFIG.set_context(old_options) + +@pytest.mark.asyncio +async def test_openai_text_to_image(): + config = Config.default() + assert config.get_openai_llm() + + data = await text_to_image("Panda emoji", size_type="512x512", config=config) + assert "base64" in data or "http" in data if __name__ == "__main__": diff --git a/tests/metagpt/memory/test_brain_memory.py b/tests/metagpt/memory/test_brain_memory.py index 32dcd672a..1f587d9f7 100644 --- a/tests/metagpt/memory/test_brain_memory.py +++ b/tests/metagpt/memory/test_brain_memory.py @@ -8,7 +8,7 @@ import pytest -from metagpt.config import LLMProviderEnum +from metagpt.configs.llm_config import LLMType from metagpt.llm import LLM from metagpt.memory.brain_memory import BrainMemory from metagpt.schema import Message @@ -46,7 +46,7 @@ def test_extract_info(input, tag, val): @pytest.mark.asyncio -@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)]) +@pytest.mark.parametrize("llm", [LLM(provider=LLMType.OPENAI), LLM(provider=LLMType.METAGPT)]) async def test_memory_llm(llm): memory = BrainMemory() for i in range(500): diff --git a/tests/metagpt/provider/test_azure_openai_api.py b/tests/metagpt/provider/test_azure_openai_api.py index f36740e65..4437eec3b 100644 --- a/tests/metagpt/provider/test_azure_openai_api.py +++ b/tests/metagpt/provider/test_azure_openai_api.py @@ -3,12 +3,8 @@ # @Desc : -from metagpt.config import CONFIG -from metagpt.provider.azure_openai_api import AzureOpenAILLM - -CONFIG.OPENAI_API_VERSION = "xx" -CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value +from metagpt.context import context def test_azure_openai_api(): - _ = AzureOpenAILLM() + _ = context.llm("azure") diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py index 3443b5078..cc781f78a 100644 --- a/tests/metagpt/provider/test_base_gpt_api.py +++ b/tests/metagpt/provider/test_base_gpt_api.py @@ -8,6 +8,7 @@ import pytest +from metagpt.configs.llm_config import LLMConfig from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message @@ -28,6 +29,9 @@ resp_content = default_chat_resp["choices"][0]["message"]["content"] class MockBaseLLM(BaseLLM): + def __init__(self, config: LLMConfig = None): + pass + def completion(self, messages: list[dict], timeout=3): return default_chat_resp @@ -102,5 +106,5 @@ async def test_async_base_llm(): resp = await base_llm.aask_batch([prompt_msg]) assert resp == resp_content - resp = await base_llm.aask_code([prompt_msg]) - assert resp == resp_content + # resp = await base_llm.aask_code([prompt_msg]) + # assert resp == resp_content diff --git a/tests/metagpt/provider/test_metagpt_api.py b/tests/metagpt/provider/test_metagpt_api.py index 1f00cb653..8f42a53c8 100644 --- a/tests/metagpt/provider/test_metagpt_api.py +++ b/tests/metagpt/provider/test_metagpt_api.py @@ -5,10 +5,10 @@ @Author : mashenquan @File : test_metagpt_api.py """ -from metagpt.config import LLMProviderEnum +from metagpt.configs.llm_config import LLMType from metagpt.llm import LLM def test_llm(): - llm = LLM(provider=LLMProviderEnum.METAGPT) + llm = LLM(provider=LLMType.METAGPT) assert llm diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 6166a82de..a996cf5b9 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -2,18 +2,18 @@ from unittest.mock import Mock import pytest -from metagpt.config import CONFIG -from metagpt.provider.openai_api import OpenAILLM +from metagpt.llm import LLM +from metagpt.logs import logger from metagpt.schema import UserMessage -CONFIG.openai_proxy = None - @pytest.mark.asyncio async def test_aask_code(): - llm = OpenAILLM() + llm = LLM(name="gpt3t") msg = [{"role": "user", "content": "Write a python hello world code."}] rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} + + logger.info(rsp) assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 @@ -21,7 +21,7 @@ async def test_aask_code(): @pytest.mark.asyncio async def test_aask_code_str(): - llm = OpenAILLM() + llm = LLM(name="gpt3t") msg = "Write a python hello world code." rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -30,8 +30,8 @@ async def test_aask_code_str(): @pytest.mark.asyncio -async def test_aask_code_Message(): - llm = OpenAILLM() +async def test_aask_code_message(): + llm = LLM(name="gpt3t") msg = UserMessage("Write a python hello world code.") rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp diff --git a/tests/metagpt/test_document.py b/tests/metagpt/test_document.py index 18650e112..e7b08544b 100644 --- a/tests/metagpt/test_document.py +++ b/tests/metagpt/test_document.py @@ -28,6 +28,6 @@ def load_existing_repo(path): def test_repo_set_load(): - repo_path = CONFIG.workspace_path / "test_repo" + repo_path = CONFIG.path / "test_repo" set_existing_repo(repo_path) load_existing_repo(repo_path) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 816c186e2..925f4b2dc 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -15,7 +15,6 @@ import pytest from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode -from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, @@ -119,8 +118,6 @@ def test_document(): assert doc.filename == meta_doc.filename assert meta_doc.content == "" - assert doc.full_path == str(CONFIG.git_repo.workdir / doc.root_path / doc.filename) - @pytest.mark.asyncio async def test_message_queue(): diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py index 38fef557e..dca71544e 100644 --- a/tests/metagpt/tools/test_azure_tts.py +++ b/tests/metagpt/tools/test_azure_tts.py @@ -32,7 +32,7 @@ async def test_azure_tts(): “Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.” """ - path = CONFIG.workspace_path / "tts" + path = CONFIG.path / "tts" path.mkdir(exist_ok=True, parents=True) filename = path / "girl.wav" filename.unlink(missing_ok=True) diff --git a/tests/metagpt/tools/test_sd_tool.py b/tests/metagpt/tools/test_sd_tool.py index e457101a9..52b970229 100644 --- a/tests/metagpt/tools/test_sd_tool.py +++ b/tests/metagpt/tools/test_sd_tool.py @@ -22,5 +22,5 @@ def test_sd_engine_generate_prompt(): async def test_sd_engine_run_t2i(): sd_engine = SDEngine() await sd_engine.run_t2i(prompts=["test"]) - img_path = CONFIG.workspace_path / "resources" / "SD_Output" / "output_0.png" + img_path = CONFIG.path / "resources" / "SD_Output" / "output_0.png" assert os.path.exists(img_path) diff --git a/tests/metagpt/utils/test_redis.py b/tests/metagpt/utils/test_redis.py index b93ff0cdb..e6e2c2ce2 100644 --- a/tests/metagpt/utils/test_redis.py +++ b/tests/metagpt/utils/test_redis.py @@ -8,38 +8,19 @@ import pytest -from metagpt.config import CONFIG +from metagpt.config2 import Config from metagpt.utils.redis import Redis @pytest.mark.asyncio async def test_redis(): - # Prerequisites - assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" - assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" - # assert CONFIG.REDIS_USER - assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD" - assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based" + redis = Config.default().redis - conn = Redis() - assert not conn.is_valid + conn = Redis(redis) await conn.set("test", "test", timeout_sec=0) assert await conn.get("test") == b"test" await conn.close() - # Mock session env - old_options = CONFIG.options.copy() - new_options = old_options.copy() - new_options["REDIS_HOST"] = "YOUR_REDIS_HOST" - CONFIG.set_context(new_options) - try: - conn = Redis() - await conn.set("test", "test", timeout_sec=0) - assert not await conn.get("test") == b"test" - await conn.close() - finally: - CONFIG.set_context(old_options) - if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py index f74e7b52a..708f4b9c3 100644 --- a/tests/metagpt/utils/test_s3.py +++ b/tests/metagpt/utils/test_s3.py @@ -11,30 +11,25 @@ from pathlib import Path import aiofiles import pytest -from metagpt.config import CONFIG +from metagpt.config2 import Config from metagpt.utils.s3 import S3 @pytest.mark.asyncio async def test_s3(): # Prerequisites - assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY" - assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY" - assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL" - # assert CONFIG.S3_SECURE: true # true/false - assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET" - - conn = S3() - assert conn.is_valid + s3 = Config.default().s3 + assert s3 + conn = S3(s3) object_name = "unittest.bak" - await conn.upload_file(bucket=CONFIG.S3_BUCKET, local_path=__file__, object_name=object_name) + await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name) pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak") pathname.unlink(missing_ok=True) - await conn.download_file(bucket=CONFIG.S3_BUCKET, object_name=object_name, local_path=str(pathname)) + await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname)) assert pathname.exists() - url = await conn.get_object_url(bucket=CONFIG.S3_BUCKET, object_name=object_name) + url = await conn.get_object_url(bucket=s3.bucket, object_name=object_name) assert url - bin_data = await conn.get_object(bucket=CONFIG.S3_BUCKET, object_name=object_name) + bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name) assert bin_data async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader: data = await reader.read() @@ -42,17 +37,13 @@ async def test_s3(): assert "http" in res # Mock session env - old_options = CONFIG.options.copy() - new_options = old_options.copy() - new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY" - CONFIG.set_context(new_options) + s3.access_key = "ABC" try: - conn = S3() - assert not conn.is_valid + conn = S3(s3) res = await conn.cache("ABC", ".bak", "script") assert not res - finally: - CONFIG.set_context(old_options) + except Exception: + pass if __name__ == "__main__":