add context and config2

This commit is contained in:
geekan 2024-01-04 21:16:23 +08:00
parent 42bb40a0f6
commit e5d11a046c
76 changed files with 922 additions and 495 deletions

4
config/config2.yaml Normal file
View file

@ -0,0 +1,4 @@
llm:
gpt3t:
api_key: "YOUR_API_KEY"
model: "gpt-3.5-turbo-1106"

View file

@ -6,7 +6,7 @@ Author: garylin2099
import re
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
@ -48,8 +48,8 @@ class CreateAgent(Action):
pattern = r"```python(.*)```"
match = re.search(pattern, rsp, re.DOTALL)
code_text = match.group(1) if match else ""
CONFIG.workspace_path.mkdir(parents=True, exist_ok=True)
new_file = CONFIG.workspace_path / "agent_created_agent.py"
config.workspace.path.mkdir(parents=True, exist_ok=True)
new_file = config.workspace.path / "agent_created_agent.py"
new_file.write_text(code_text)
return code_text

View file

@ -8,7 +8,7 @@ import asyncio
from langchain.embeddings import OpenAIEmbeddings
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import DATA_PATH, EXAMPLE_PATH
from metagpt.document_store import FaissStore
from metagpt.logs import logger
@ -16,7 +16,8 @@ from metagpt.roles import Sales
def get_store():
embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url)
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
return FaissStore(DATA_PATH / "example.json", embedding=embedding)

View file

@ -13,6 +13,7 @@ from typing import Optional, Union
from pydantic import ConfigDict, Field, model_validator
from metagpt.actions.action_node import ActionNode
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import (
@ -33,14 +34,41 @@ class Action(SerializationMixin, is_polymorphic_base=True):
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
_context: Optional[Context] = Field(default=None, exclude=True)
@property
def git_repo(self):
return self._context.git_repo
@property
def src_workspace(self):
return self._context.src_workspace
@property
def prompt_schema(self):
return self._context.config.prompt_schema
@property
def project_name(self):
return self._context.config.project_name
@project_name.setter
def project_name(self, value):
self._context.config.project_name = value
@property
def project_path(self):
return self._context.config.project_path
@model_validator(mode="before")
@classmethod
def set_name_if_empty(cls, values):
if "name" not in values or not values["name"]:
values["name"] = cls.__name__
return values
@model_validator(mode="before")
@classmethod
def _init_with_instruction(cls, values):
if "instruction" in values:
name = values["name"]

View file

@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@ -262,7 +261,7 @@ class ActionNode:
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=CONFIG.timeout,
timeout=None,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
@ -294,7 +293,7 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
async def simple_fill(self, schema, mode, timeout=None, exclude=None):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
@ -309,7 +308,7 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=None, exclude=[]):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.

View file

@ -9,12 +9,13 @@
2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
"""
import re
from typing import Optional
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.common import CodeParser
@ -49,8 +50,8 @@ Now you should start rewriting the code:
class DebugError(Action):
name: str = "DebugError"
context: RunCodeContext = Field(default_factory=RunCodeContext)
_context: Optional[Context] = None
async def run(self, *args, **kwargs) -> str:
output_doc = await FileRepository.get_file(
@ -66,7 +67,7 @@ class DebugError(Action):
logger.info(f"Debug and rewrite {self.context.test_filename}")
code_doc = await FileRepository.get_file(
filename=self.context.code_filename, relative_path=CONFIG.src_workspace
filename=self.context.code_filename, relative_path=self._context.src_workspace
)
if not code_doc:
return ""

View file

@ -15,7 +15,6 @@ from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.design_api_an import DESIGN_API_NODE
from metagpt.config import CONFIG
from metagpt.const import (
DATA_API_DESIGN_FILE_REPO,
PRDS_FILE_REPO,
@ -46,13 +45,13 @@ class WriteDesign(Action):
"clearly and in detail."
)
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
async def run(self, with_messages: Message, schema: str = None):
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO)
changed_prds = prds_file_repo.changed_files
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
# changes.
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
# For those PRDs and design documents that have undergone changes, regenerate the design content.
@ -76,11 +75,11 @@ class WriteDesign(Action):
# leaving room for global optimization in subsequent steps.
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
async def _new_system_design(self, context, schema=None):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
return node
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
async def _merge(self, prd_doc, system_design_doc, schema=None):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
system_design_doc.content = node.instruct_content.model_dump_json()
@ -106,23 +105,21 @@ class WriteDesign(Action):
await self._save_pdf(doc)
return doc
@staticmethod
async def _save_data_api_design(design_doc):
async def _save_data_api_design(self, design_doc):
m = json.loads(design_doc.content)
data_api_design = m.get("Data structures and interfaces")
if not data_api_design:
return
pathname = CONFIG.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
pathname = self.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await WriteDesign._save_mermaid_file(data_api_design, pathname)
logger.info(f"Save class view to {str(pathname)}")
@staticmethod
async def _save_seq_flow(design_doc):
async def _save_seq_flow(self, design_doc):
m = json.loads(design_doc.content)
seq_flow = m.get("Program call flow")
if not seq_flow:
return
pathname = CONFIG.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
pathname = self.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await WriteDesign._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")

View file

@ -12,7 +12,6 @@ from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.config import CONFIG
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.schema import Document
from metagpt.utils.file_repository import FileRepository
@ -25,18 +24,22 @@ class PrepareDocuments(Action):
name: str = "PrepareDocuments"
context: Optional[str] = None
@property
def config(self):
return self._context.config
def _init_repo(self):
"""Initialize the Git environment."""
if not CONFIG.project_path:
name = CONFIG.project_name or FileRepository.new_filename()
path = Path(CONFIG.workspace_path) / name
if not self.config.project_path:
name = self.config.project_name or FileRepository.new_filename()
path = Path(self.config.workspace.path) / name
else:
path = Path(CONFIG.project_path)
if path.exists() and not CONFIG.inc:
path = Path(self.config.project_path)
if path.exists() and not self.config.inc:
shutil.rmtree(path)
CONFIG.project_path = path
CONFIG.project_name = path.name
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
self.config.project_path = path
self.config.project_name = path.name
self._context.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""

View file

@ -16,7 +16,6 @@ from typing import Optional
from metagpt.actions import ActionOutput
from metagpt.actions.action import Action
from metagpt.actions.project_management_an import PM_NODE
from metagpt.config import CONFIG
from metagpt.const import (
PACKAGE_REQUIREMENTS_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
@ -40,11 +39,15 @@ class WriteTasks(Action):
name: str = "CreateTasks"
context: Optional[str] = None
async def run(self, with_messages, schema=CONFIG.prompt_schema):
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
@property
def prompt_schema(self):
return self._context.config.prompt_schema
async def run(self, with_messages, schema=None):
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
tasks_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
tasks_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO)
changed_tasks = tasks_file_repo.changed_files
change_files = Documents()
# Rewrite the system designs that have undergone changes based on the git head diff under
@ -87,21 +90,20 @@ class WriteTasks(Action):
await self._save_pdf(task_doc=task_doc)
return task_doc
async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
node = await PM_NODE.fill(context, self.llm, schema)
async def _run_new_tasks(self, context):
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
return node
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
async def _merge(self, system_design_doc, task_doc) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
node = await PM_NODE.fill(context, self.llm, schema)
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
task_doc.content = node.instruct_content.model_dump_json()
return task_doc
@staticmethod
async def _update_requirements(doc):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python third-party packages", set()))
file_repo = CONFIG.git_repo.new_file_repository()
file_repo = self.git_repo.new_file_repository()
requirement_doc = await file_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")

View file

@ -10,7 +10,6 @@ import re
from pathlib import Path
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
from metagpt.repo_parser import RepoParser
from metagpt.utils.di_graph_repository import DiGraphRepository
@ -21,8 +20,8 @@ class RebuildClassView(Action):
def __init__(self, name="", context=None, llm=None):
super().__init__(name=name, context=context, llm=llm)
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
async def run(self, with_messages=None):
graph_repo_pathname = self.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=self.context)
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
@ -57,7 +56,7 @@ class RebuildClassView(Action):
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
class_view_file_repo = self.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
all_class_view = []
for spo in dataset:

View file

@ -21,7 +21,6 @@ from typing import Tuple
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.exceptions import handle_exception
@ -89,13 +88,12 @@ class RunCode(Action):
return "", str(e)
return namespace.get("result", ""), ""
@classmethod
async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
working_directory = str(working_directory)
additional_python_paths = [str(path) for path in additional_python_paths]
# Copy the current environment variables
env = CONFIG.new_environ()
env = self._context.new_environ()
# Modify the PYTHONPATH environment variable
additional_python_paths = [working_directory] + additional_python_paths

View file

@ -11,7 +11,7 @@ import pydantic
from pydantic import Field, model_validator
from metagpt.actions import Action
from metagpt.config import CONFIG, Config
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
@ -103,12 +103,11 @@ You are a member of a professional butler team and will provide helpful suggesti
"""
# TOTEST
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None
config: None = Field(default_factory=Config)
engine: Optional[SearchEngineType] = CONFIG.search_engine
engine: Optional[SearchEngineType] = None
search_func: Optional[Any] = None
search_engine: SearchEngine = None
result: str = ""

View file

@ -11,7 +11,6 @@ from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
@ -105,7 +104,7 @@ class SummarizeCode(Action):
design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
task_pathname = Path(self.context.task_filename)
task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
src_file_repo = self.git_repo.new_file_repository(relative_path=self._context.src_workspace)
code_blocks = []
for filename in self.context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -21,7 +21,6 @@ from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import (
BUGFIX_FILENAME,
CODE_SUMMARIES_FILE_REPO,
@ -114,7 +113,12 @@ class WriteCode(Action):
if bug_feedback:
code_context = coding_context.code_doc.content
else:
code_context = await self.get_codes(coding_context.task_doc, exclude=self.context.filename)
code_context = await self.get_codes(
coding_context.task_doc,
exclude=self.context.filename,
git_repo=self.git_repo,
src_workspace=self._context.src_workspace,
)
prompt = PROMPT_TEMPLATE.format(
design=coding_context.design_doc.content if coding_context.design_doc else "",
@ -129,13 +133,13 @@ class WriteCode(Action):
code = await self.write_code(prompt)
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
root_path = self._context.src_workspace if self._context.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
coding_context.code_doc.content = code
return coding_context
@staticmethod
async def get_codes(task_doc, exclude) -> str:
async def get_codes(task_doc, exclude, git_repo, src_workspace) -> str:
if not task_doc:
return ""
if not task_doc.content:
@ -143,7 +147,7 @@ class WriteCode(Action):
m = json.loads(task_doc.content)
code_filenames = m.get("Task list", [])
codes = []
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
src_file_repo = git_repo.new_file_repository(relative_path=src_workspace)
for filename in code_filenames:
if filename == exclude:
continue

View file

@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import WriteCode
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser
@ -137,11 +136,16 @@ class WriteCodeReview(Action):
async def run(self, *args, **kwargs) -> CodingContext:
iterative_code = self.context.code_doc.content
k = CONFIG.code_review_k_times or 1
k = self._context.config.code_review_k_times or 1
for i in range(k):
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
task_content = self.context.task_doc.content if self.context.task_doc else ""
code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename)
code_context = await WriteCode.get_codes(
self.context.task_doc,
exclude=self.context.filename,
git_repo=self._context.git_repo,
src_workspace=self.src_workspace,
)
context = "\n".join(
[
"## System Design\n" + str(self.context.design_doc) + "\n",

View file

@ -26,7 +26,6 @@ from metagpt.actions.write_prd_an import (
WP_ISSUE_TYPE_NODE,
WRITE_PRD_NODE,
)
from metagpt.config import CONFIG
from metagpt.const import (
BUGFIX_FILENAME,
COMPETITIVE_ANALYSIS_FILE_REPO,
@ -65,10 +64,10 @@ class WritePRD(Action):
name: str = "WritePRD"
content: Optional[str] = None
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
docs_file_repo = self.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME)
if requirement_doc and await self._is_bugfix(requirement_doc.content):
await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
@ -85,7 +84,7 @@ class WritePRD(Action):
else:
await docs_file_repo.delete(filename=BUGFIX_FILENAME)
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO)
prd_docs = await prds_file_repo.get_all()
change_files = Documents()
for prd_doc in prd_docs:
@ -109,7 +108,7 @@ class WritePRD(Action):
# optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
async def _run_new_requirement(self, requirements) -> ActionOutput:
# sas = SearchAndSummarize()
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
# rsp = ""
@ -117,7 +116,7 @@ class WritePRD(Action):
# if sas.result:
# logger.info(sas.result)
# logger.info(rsp)
project_name = CONFIG.project_name if CONFIG.project_name else ""
project_name = self.project_name
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
@ -129,11 +128,11 @@ class WritePRD(Action):
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
async def _merge(self, new_requirement_doc, prd_doc) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
prd_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return prd_doc
@ -157,15 +156,12 @@ class WritePRD(Action):
await self._save_pdf(new_prd_doc)
return new_prd_doc
@staticmethod
async def _save_competitive_analysis(prd_doc):
async def _save_competitive_analysis(self, prd_doc):
m = json.loads(prd_doc.content)
quadrant_chart = m.get("Competitive Quadrant Chart")
if not quadrant_chart:
return
pathname = (
CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
)
pathname = self.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
if not pathname.parent.exists():
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(quadrant_chart, pathname)
@ -174,20 +170,19 @@ class WritePRD(Action):
async def _save_pdf(prd_doc):
await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
@staticmethod
async def _rename_workspace(prd):
if not CONFIG.project_name:
async def _rename_workspace(self, prd):
if not self.project_name:
if isinstance(prd, (ActionOutput, ActionNode)):
ws_name = prd.instruct_content.model_dump()["Project Name"]
else:
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
CONFIG.project_name = ws_name
CONFIG.git_repo.rename_root(CONFIG.project_name)
self.project_name = ws_name
self.git_repo.rename_root(self.project_name)
async def _is_bugfix(self, context) -> bool:
src_workspace_path = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
code_files = CONFIG.git_repo.get_files(relative_path=src_workspace_path)
src_workspace_path = self.git_repo.workdir / self.git_repo.workdir.name
code_files = self.git_repo.get_files(relative_path=src_workspace_path)
if not code_files:
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)

View file

@ -75,6 +75,7 @@ class WriteTeachingPlanPart(Action):
if "{" not in value:
return value
# FIXME: 从Context中获取参数
merged_opts = CONFIG.options or {}
try:
return value.format(**merged_opts)

View file

@ -11,7 +11,6 @@
from typing import Optional
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Document, TestingContext
@ -64,7 +63,7 @@ class WriteTest(Action):
code_to_test=self.context.code_doc.content,
test_file_name=self.context.test_doc.filename,
source_file_path=self.context.code_doc.root_relative_path,
workspace=CONFIG.git_repo.workdir,
workspace=self.git_repo.workdir,
)
self.context.test_doc.content = await self.write_code(prompt)
return self.context

View file

@ -11,13 +11,13 @@ import json
import os
import warnings
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any
from uuid import uuid4
import yaml
from metagpt.configs.llm_config import LLMType
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
from metagpt.logs import logger
from metagpt.tools import SearchEngineType, WebBrowserEngineType
@ -38,19 +38,6 @@ class NotConfiguredException(Exception):
super().__init__(self.message)
class LLMProviderEnum(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
SPARK = "spark"
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE_OPENAI = "azure_openai"
OLLAMA = "ollama"
class Config(metaclass=Singleton):
"""
Regular usage method:
@ -81,27 +68,25 @@ class Config(metaclass=Singleton):
global_options.update(OPTIONS.get())
logger.debug("Config loading done.")
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
def get_default_llm_provider_enum(self) -> LLMType:
"""Get first valid LLM provider enum"""
mappings = {
LLMProviderEnum.OPENAI: bool(
LLMType.OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
),
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
LLMProviderEnum.METAGPT: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
),
LLMProviderEnum.AZURE_OPENAI: bool(
LLMType.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
LLMType.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
LLMType.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
LLMType.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
LLMType.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
LLMType.METAGPT: bool(self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"),
LLMType.AZURE_OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY)
and self.OPENAI_API_TYPE == "azure"
and self.DEPLOYMENT_NAME
and self.OPENAI_API_VERSION
),
LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
LLMType.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
}
provider = None
for k, v in mappings.items():
@ -109,7 +94,7 @@ class Config(metaclass=Singleton):
provider = k
break
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
if provider is LLMType.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
model_name = self.get_model_name(provider=provider)
if model_name:
@ -122,8 +107,8 @@ class Config(metaclass=Singleton):
def get_model_name(self, provider=None) -> str:
provider = provider or self.get_default_llm_provider_enum()
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
LLMType.OPENAI: self.OPENAI_API_MODEL,
LLMType.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
return model_mappings.get(provider, "")
@ -166,6 +151,7 @@ class Config(metaclass=Singleton):
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
self.claude_api_key = self._get("ANTHROPIC_API_KEY")
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
self.serper_api_key = self._get("SERPER_API_KEY")
self.google_api_key = self._get("GOOGLE_API_KEY")
@ -200,7 +186,7 @@ class Config(metaclass=Singleton):
self.workspace_path = self.workspace_path / workspace_uid
self._ensure_workspace_exists()
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
self.timeout = int(self._get("TIMEOUT", 3))
self.timeout = int(self._get("TIMEOUT", 60))
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""

124
metagpt/config2.py Normal file
View file

@ -0,0 +1,124 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 01:25
@Author : alexanderwu
@File : llm_factory.py
"""
import os
from pathlib import Path
from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, Field, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
from metagpt.configs.s3_config import S3Config
from metagpt.configs.search_config import SearchConfig
from metagpt.configs.workspace_config import WorkspaceConfig
from metagpt.const import METAGPT_ROOT
from metagpt.utils.yaml_model import YamlModel
class CLIParams(BaseModel):
project_path: str = ""
project_name: str = ""
inc: bool = False
reqa_file: str = ""
max_auto_summarize_code: int = 0
git_reinit: bool = False
@model_validator(mode="after")
def check_project_path(self):
if self.project_path:
self.inc = True
self.project_name = self.project_name or Path(self.project_path).name
class Config(CLIParams, YamlModel):
# Key Parameters
llm: Dict[str, LLMConfig] = Field(default_factory=Dict)
# Global Proxy. Will be used if llm.proxy is not set
proxy: str = ""
# Tool Parameters
search: Dict[str, SearchConfig] = {}
browser: Dict[str, BrowserConfig] = {"default": BrowserConfig()}
mermaid: Dict[str, MermaidConfig] = {"default": MermaidConfig()}
# Storage Parameters
s3: Optional[S3Config] = None
redis: Optional[RedisConfig] = None
# Misc Parameters
repair_llm_output: bool = False
prompt_schema: Literal["json", "markdown", "raw"] = "json"
workspace: WorkspaceConfig = WorkspaceConfig()
enable_longterm_memory: bool = False
code_review_k_times: int = 2
# Will be removed in the future
llm_for_researcher_summary: str = "gpt3"
llm_for_researcher_report: str = "gpt3"
METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = ""
@classmethod
def default(cls):
"""Load default config
- Priority: env < default_config_paths
- Inside default_config_paths, the latter one overwrites the former one
"""
default_config_paths: List[Path] = [
METAGPT_ROOT / "config/config2.yaml",
Path.home() / ".metagpt/config2.yaml",
]
dicts = [dict(os.environ)]
dicts += [Config.read_yaml(path) for path in default_config_paths]
final = merge_dict(dicts)
return Config(**final)
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
if project_path:
inc = True
project_name = project_name or Path(project_path).name
self.project_path = project_path
self.project_name = project_name
self.inc = inc
self.reqa_file = reqa_file
self.max_auto_summarize_code = max_auto_summarize_code
def get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
"""Get LLM instance by name"""
if name is None:
# Use the first LLM as default
name = list(self.llm.keys())[0]
if name not in self.llm:
raise ValueError(f"LLM {name} not found in config")
return self.llm[name]
def get_openai_llm(self, name: Optional[str] = None) -> LLMConfig:
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
if name is None:
# Use the first OpenAI LLM as default
name = [k for k, v in self.llm.items() if v.api_type == LLMType.OPENAI][0]
if name not in self.llm:
raise ValueError(f"OpenAI LLM {name} not found in config")
return self.llm[name]
def merge_dict(dicts: Iterable[Dict]) -> Dict:
"""Merge multiple dicts into one, with the latter dict overwriting the former"""
result = {}
for dictionary in dicts:
result.update(dictionary)
return result
config = Config.default()

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 16:33
@Author : alexanderwu
@File : __init__.py
"""

View file

@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:06
@Author : alexanderwu
@File : browser_config.py
"""
from typing import Literal
from metagpt.tools import WebBrowserEngineType
from metagpt.utils.yaml_model import YamlModel
class BrowserConfig(YamlModel):
"""Config for Browser"""
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
path: str = ""

View file

@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 16:33
@Author : alexanderwu
@File : llm_config.py
"""
from enum import Enum
from typing import Optional
from pydantic import field_validator
from metagpt.utils.yaml_model import YamlModel
class LLMType(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
SPARK = "spark"
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE_OPENAI = "azure"
OLLAMA = "ollama"
class LLMConfig(YamlModel):
"""Config for LLM
OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions.py#L681
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None
api_secret: Optional[str] = None
domain: Optional[str] = None
# For Chat Completion
max_token: int = 4096
temperature: float = 0.0
top_p: float = 1.0
top_k: int = 0
repetition_penalty: float = 1.0
stop: Optional[str] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
best_of: Optional[int] = None
n: Optional[int] = None
stream: bool = False
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
top_logprobs: Optional[int] = None
timeout: int = 60
# For Network
proxy: Optional[str] = None
# Cost Control
calc_usage: bool = True
@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):
if v in ["", None, "YOUR_API_KEY"]:
raise ValueError("Please set your API key in config.yaml")
return v

View file

@ -0,0 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:07
@Author : alexanderwu
@File : mermaid_config.py
"""
from typing import Literal
from metagpt.utils.yaml_model import YamlModel
class MermaidConfig(YamlModel):
"""Config for Mermaid"""
engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs"
path: str = ""
puppeteer_config: str = "" # Only for nodejs engine

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:06
@Author : alexanderwu
@File : redis_config.py
"""
from metagpt.utils.yaml_model import YamlModelWithoutDefault
class RedisConfig(YamlModelWithoutDefault):
host: str
port: int
username: str = ""
password: str
db: str
def to_url(self):
return f"redis://{self.host}:{self.port}"
def to_kwargs(self):
return {
"username": self.username,
"password": self.password,
"db": self.db,
}

View file

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:07
@Author : alexanderwu
@File : s3_config.py
"""
from metagpt.utils.yaml_model import YamlModelWithoutDefault
class S3Config(YamlModelWithoutDefault):
access_key: str
secret_key: str
endpoint: str
bucket: str

View file

@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:06
@Author : alexanderwu
@File : search_config.py
"""
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
api_key: str
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
cse_id: str = "" # for google

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:09
@Author : alexanderwu
@File : workspace_config.py
"""
from datetime import datetime
from pathlib import Path
from uuid import uuid4
from pydantic import field_validator, model_validator
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.utils.yaml_model import YamlModel
class WorkspaceConfig(YamlModel):
path: Path = DEFAULT_WORKSPACE_ROOT
use_uid: bool = False
uid: str = ""
@field_validator("path")
@classmethod
def check_workspace_path(cls, v):
if isinstance(v, str):
v = Path(v)
return v
@model_validator(mode="after")
def check_uid_and_update_path(self):
if self.use_uid and not self.uid:
self.uid = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
self.path = self.path / self.uid
# Create workspace path if not exists
self.path.mkdir(parents=True, exist_ok=True)
return self

55
metagpt/context.py Normal file
View file

@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 16:32
@Author : alexanderwu
@File : context.py
"""
import os
from pathlib import Path
from typing import Dict, Optional
from pydantic import BaseModel
from metagpt.config2 import Config
from metagpt.const import OPTIONS
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import get_llm
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.git_repository import GitRepository
class Context(BaseModel):
kwargs: Dict = {}
config: Config = Config.default()
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
@property
def options(self):
"""Return all key-values"""
return OPTIONS.get()
def new_environ(self):
"""Return a new os.environ object"""
env = os.environ.copy()
i = self.options
env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
def llm(self, name: Optional[str] = None) -> BaseLLM:
"""Return a LLM instance"""
llm = get_llm(self.config.get_llm_config(name))
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
return llm
# Global context
context = Context()
if __name__ == "__main__":
print(context.model_dump_json(indent=4))
print(context.config.get_openai_llm())

View file

@ -9,14 +9,13 @@ import asyncio
from pathlib import Path
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.config import CONFIG
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
from metagpt.logs import logger
from metagpt.utils.embedding import get_embedding
class FaissStore(LocalStore):
@ -25,9 +24,7 @@ class FaissStore(LocalStore):
):
self.meta_col = meta_col
self.content_col = content_col
self.embedding = embedding or OpenAIEmbeddings(
openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
)
self.embedding = embedding or get_embedding()
super().__init__(raw_data, cache_dir)
def _load(self) -> Optional["FaissStore"]:

View file

@ -17,7 +17,7 @@ from typing import Iterable, Set
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.config import CONFIG
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
@ -35,6 +35,7 @@ class Environment(BaseModel):
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
context: Context = Field(default_factory=Context, exclude=True)
@model_validator(mode="after")
def init_roles(self):
@ -85,6 +86,7 @@ class Environment(BaseModel):
"""
self.roles[role.profile] = role
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable[Role]):
"""增加一批在当前环境的角色
@ -95,6 +97,7 @@ class Environment(BaseModel):
for role in roles: # setup system message with roles
role.set_env(self)
role.context = self.context
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""
@ -162,7 +165,6 @@ class Environment(BaseModel):
"""Set the labels for message to be consumed by the object"""
self.members[obj] = tags
@staticmethod
def archive(auto_archive=True):
if auto_archive and CONFIG.git_repo:
CONFIG.git_repo.archive()
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:
self.context.git_repo.archive()

View file

@ -8,33 +8,37 @@
"""
import base64
from metagpt.config import CONFIG
from metagpt.config2 import Config
from metagpt.const import BASE64_FORMAT
from metagpt.llm import LLM
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
from metagpt.utils.s3 import S3
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
async def text_to_image(text, size_type: str = "512x512", model_url="", config: Config = None):
"""Text to image
:param text: The text used for image conversion.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
:param model_url: MetaGPT model url
:param config: Config
:return: The image data is returned in Base64 encoding.
"""
image_declaration = "data:image/png;base64,"
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
if model_url:
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
elif CONFIG.OPENAI_API_KEY or openai_api_key:
binary_data = await oas3_openai_text_to_image(text, size_type)
elif oai_llm := config.get_openai_llm():
binary_data = await oas3_openai_text_to_image(text, size_type, LLM(oai_llm))
else:
raise ValueError("Missing necessary parameters.")
base64_data = base64.b64encode(binary_data).decode("utf-8")
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
assert config.s3, "S3 config is required."
s3 = S3(config.s3)
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT)
if url:
return f"![{text}]({url})"
return image_declaration + base64_data if base64_data else ""

View file

@ -48,7 +48,7 @@ async def text_to_speech(
audio_declaration = "data:audio/wav;base64,"
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else ""
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT)
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data
@ -60,7 +60,7 @@ async def text_to_speech(
text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
)
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else ""
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT)
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data

View file

@ -8,17 +8,10 @@
from typing import Optional
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.context import context
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
_ = HumanProvider() # Avoid pre-commit error
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
"""get the default llm provider"""
if provider is None:
provider = CONFIG.get_default_llm_provider_enum()
return LLM_REGISTRY.get_provider(provider)
def LLM(name: Optional[str] = None) -> BaseLLM:
"""get the default llm provider if name is None"""
return context.llm(name)

View file

@ -14,6 +14,7 @@ from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
__all__ = [
"FireworksLLM",
@ -24,4 +25,5 @@ __all__ = [
"AzureOpenAILLM",
"MetaGPTLLM",
"OllamaLLM",
"HumanProvider",
]

View file

@ -9,12 +9,15 @@
import anthropic
from anthropic import Anthropic, AsyncAnthropic
from metagpt.config import CONFIG
from metagpt.configs.llm_config import LLMConfig
class Claude2:
def __init__(self, config: LLMConfig = None):
self.config = config
def ask(self, prompt: str) -> str:
client = Anthropic(api_key=CONFIG.anthropic_api_key)
client = Anthropic(api_key=self.config.api_key)
res = client.completions.create(
model="claude-2",
@ -24,7 +27,7 @@ class Claude2:
return res.completion
async def aask(self, prompt: str) -> str:
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
aclient = AsyncAnthropic(api_key=self.config.api_key)
res = await aclient.completions.create(
model="claude-2",

View file

@ -13,12 +13,12 @@
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMType
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
@register_provider(LLMProviderEnum.AZURE_OPENAI)
@register_provider(LLMType.AZURE_OPENAI)
class AzureOpenAILLM(OpenAILLM):
"""
Check https://platform.openai.com/examples for examples

View file

@ -8,15 +8,30 @@
"""
import json
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
from openai import AsyncOpenAI
from metagpt.configs.llm_config import LLMConfig
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
class BaseLLM(ABC):
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
config: LLMConfig
use_system_prompt: bool = True
system_prompt = "You are a helpful assistant."
# OpenAI / Azure / Others
aclient: Optional[Union[AsyncOpenAI]] = None
cost_manager: Optional[CostManager] = None
@abstractmethod
def __init__(self, config: LLMConfig = None):
pass
def _user_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "content": msg}
@ -63,10 +78,9 @@ class BaseLLM(ABC):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def aask_code(self, msgs: list[str], timeout=3) -> str:
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict:
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
rsp_text = await self.aask_batch(msgs, timeout=timeout)
return rsp_text
raise NotImplementedError
@abstractmethod
async def acompletion(self, messages: list[dict], timeout=3):

View file

@ -15,7 +15,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
@ -64,27 +64,18 @@ class FireworksCostManager(CostManager):
token_costs = self.model_grade_token_costs(model)
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
self.total_cost += cost
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
logger.info(
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${max_budget:.3f} | "
f"Total running cost: ${self.total_cost:.4f}"
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.FIREWORKS)
@register_provider(LLMType.FIREWORKS)
class FireworksLLM(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_fireworks()
def __init__(self, config: LLMConfig = None):
super().__init__(config=config)
self.auto_max_tokens = False
self._cost_manager = FireworksCostManager()
def __init_fireworks(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._init_client()
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
self.cost_manager = FireworksCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
@ -94,14 +85,14 @@ class FireworksLLM(OpenAILLM):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)

View file

@ -19,7 +19,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -41,21 +41,21 @@ class GeminiGenerativeModel(GenerativeModel):
return await self._async_client.count_tokens(model=self.model_name, contents=contents)
@register_provider(LLMProviderEnum.GEMINI)
@register_provider(LLMType.GEMINI)
class GeminiLLM(BaseLLM):
"""
Refs to `https://ai.google.dev/tutorials/python_quickstart`
"""
def __init__(self):
def __init__(self, config: LLMConfig = None):
self.use_system_prompt = False # google gemini has no system prompt when use api
self.__init_gemini(CONFIG)
self.__init_gemini(config)
self.model = "gemini-pro" # so far only one model
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: CONFIG):
genai.configure(api_key=config.gemini_api_key)
def __init_gemini(self, config: LLMConfig):
genai.configure(api_key=config.api_key)
def _user_msg(self, msg: str) -> dict[str, str]:
# Not to change BaseLLM default functions but update with Gemini's conversation format.
@ -71,11 +71,11 @@ class GeminiLLM(BaseLLM):
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if CONFIG.calc_usage:
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
@ -108,7 +108,7 @@ class GeminiLLM(BaseLLM):
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:

View file

@ -5,6 +5,7 @@ Author: garylin2099
"""
from typing import Optional
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
@ -14,6 +15,9 @@ class HumanProvider(BaseLLM):
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
"""
def __init__(self, config: LLMConfig = None):
pass
def ask(self, msg: str, timeout=3) -> str:
logger.info("It's your turn, please type in your response. You may also refer to the context below")
rsp = input(msg)

View file

@ -5,7 +5,8 @@
@Author : alexanderwu
@File : llm_provider_registry.py
"""
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
class LLMProviderRegistry:
@ -15,13 +16,9 @@ class LLMProviderRegistry:
def register(self, key, provider_cls):
self.providers[key] = provider_cls
def get_provider(self, enum: LLMProviderEnum):
def get_provider(self, enum: LLMType):
"""get provider instance according to the enum"""
return self.providers[enum]()
# Registry instance
LLM_REGISTRY = LLMProviderRegistry()
return self.providers[enum]
def register_provider(key):
@ -32,3 +29,12 @@ def register_provider(key):
return cls
return decorator
def get_llm(config: LLMConfig) -> BaseLLM:
"""get the default llm provider"""
return LLM_REGISTRY.get_provider(config.api_type)(config)
# Registry instance
LLM_REGISTRY = LLMProviderRegistry()

View file

@ -5,12 +5,11 @@
@File : metagpt_api.py
@Desc : MetaGPT LLM provider.
"""
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMType
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.METAGPT)
@register_provider(LLMType.METAGPT)
class MetaGPTLLM(OpenAILLM):
def __init__(self):
super().__init__()
pass

View file

@ -13,48 +13,33 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import TokenCostManager
class OllamaCostManager(CostManager):
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
logger.info(
f"Max budget: ${max_budget:.3f} | "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OLLAMA)
@register_provider(LLMType.OLLAMA)
class OllamaLLM(BaseLLM):
"""
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
"""
def __init__(self):
self.__init_ollama(CONFIG)
self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
def __init__(self, config: LLMConfig = None):
self.__init_ollama(config)
self.client = GeneralAPIRequestor(base_url=config.api_base)
self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self._cost_manager = OllamaCostManager()
self._cost_manager = TokenCostManager()
def __init_ollama(self, config: CONFIG):
assert config.ollama_api_base
self.model = config.ollama_api_model
def __init_ollama(self, config: LLMConfig):
assert config.api_base
self.model = config.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
@ -62,7 +47,7 @@ class OllamaLLM(BaseLLM):
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if CONFIG.calc_usage:
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))

View file

@ -4,56 +4,27 @@
from openai.types import CompletionUsage
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
class OpenLLMCostManager(CostManager):
"""open llm model is self-host, it's free and without cost"""
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
logger.info(
f"Max budget: ${max_budget:.3f} | reference "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
@register_provider(LLMProviderEnum.OPEN_LLM)
@register_provider(LLMType.OPEN_LLM)
class OpenLLM(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_openllm()
self.auto_max_tokens = False
self._cost_manager = OpenLLMCostManager()
def __init_openllm(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._init_client()
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
def __init__(self, config: LLMConfig):
super().__init__(config)
self._cost_manager = TokenCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
return kwargs
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not CONFIG.calc_usage:
if not self.config.calc_usage:
return usage
try:

View file

@ -10,7 +10,7 @@
"""
import json
from typing import AsyncIterator, Union
from typing import AsyncIterator, Optional, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
@ -24,13 +24,13 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.cost_manager import Costs
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -50,18 +50,19 @@ See FAQ 5.8
raise retry_state.outcome.exception()
@register_provider(LLMProviderEnum.OPENAI)
@register_provider(LLMType.OPENAI)
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
def __init__(self, config: LLMConfig = None):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_openai(self):
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
@ -69,7 +70,7 @@ class OpenAILLM(BaseLLM):
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
kwargs = {"api_key": self.config.api_key, "base_url": self.config.base_url}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
@ -79,10 +80,10 @@ class OpenAILLM(BaseLLM):
def _get_proxy_params(self) -> dict:
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
if self.config.proxy:
params = {"proxies": self.config.proxy}
if self.config.base_url:
params["base_url"] = self.config.base_url
return params
@ -103,7 +104,7 @@ class OpenAILLM(BaseLLM):
"stop": None,
"temperature": 0.3,
"model": self.model,
"timeout": max(CONFIG.timeout, timeout),
"timeout": max(self.config.timeout, timeout),
}
if extra_kwargs:
kwargs.update(extra_kwargs)
@ -205,7 +206,7 @@ class OpenAILLM(BaseLLM):
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not CONFIG.calc_usage:
if not self.config.calc_usage:
return usage
try:
@ -218,16 +219,16 @@ class OpenAILLM(BaseLLM):
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage and usage:
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
if self.config.calc_usage and usage:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
return CONFIG.cost_manager.get_costs()
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
return self.config.max_token
return get_max_completion_tokens(messages, self.model, self.config.max_tokens)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):

View file

@ -16,15 +16,16 @@ from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
@register_provider(LLMType.SPARK)
class SparkLLM(BaseLLM):
def __init__(self):
def __init__(self, config: LLMConfig = None):
self.config = config
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def get_choice_text(self, rsp: dict) -> str:
@ -33,12 +34,12 @@ class SparkLLM(BaseLLM):
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)
w = GetMessageFromWeb(messages, self.config)
return w.run()
async def acompletion(self, messages: list[dict], timeout=3):
# 不支持异步
w = GetMessageFromWeb(messages)
w = GetMessageFromWeb(messages, self.config)
return w.run()
@ -89,14 +90,14 @@ class GetMessageFromWeb:
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def __init__(self, text):
def __init__(self, text, config):
self.text = text
self.ret = ""
self.spark_appid = CONFIG.spark_appid
self.spark_api_secret = CONFIG.spark_api_secret
self.spark_api_key = CONFIG.spark_api_key
self.domain = CONFIG.domain
self.spark_url = CONFIG.spark_url
self.spark_appid = config.app_id
self.spark_api_secret = config.api_secret
self.spark_api_key = config.api_key
self.domain = config.domain
self.spark_url = config.base_url
def on_message(self, ws, message):
data = json.loads(message)

View file

@ -16,7 +16,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -31,27 +31,27 @@ class ZhiPuEvent(Enum):
FINISH = "finish"
@register_provider(LLMProviderEnum.ZHIPUAI)
@register_provider(LLMType.ZHIPUAI)
class ZhiPuAILLM(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
"""
def __init__(self):
self.__init_zhipuai(CONFIG)
def __init__(self, config: LLMConfig = None):
self.__init_zhipuai(config)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
def __init_zhipuai(self, config: LLMConfig):
assert config.api_key
zhipuai.api_key = config.api_key
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
if config.openai_proxy:
if config.proxy:
# FIXME: openai v1.x sdk has no proxy support
openai.proxy = config.openai_proxy
openai.proxy = config.proxy
def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
@ -59,11 +59,11 @@ class ZhiPuAILLM(BaseLLM):
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if CONFIG.calc_usage:
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")

View file

@ -27,7 +27,6 @@ from typing import Set
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
from metagpt.actions.fix_bug import FixBug
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
CODE_SUMMARIES_PDF_FILE_REPO,
@ -80,6 +79,7 @@ class Engineer(Role):
code_todos: list = []
summarize_todos: list = []
next_todo_action: str = ""
n_summarize: int = 0
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@ -97,7 +97,7 @@ class Engineer(Role):
async def _act_sp_with_cr(self, review=False) -> Set[str]:
changed_files = set()
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
for todo in self.code_todos:
"""
# Select essential information from the historical data to reduce the length of the prompt (summarized from human experience):
@ -153,10 +153,10 @@ class Engineer(Role):
)
async def _act_summarize(self):
code_summaries_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
code_summaries_pdf_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
code_summaries_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
code_summaries_pdf_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
tasks = []
src_relative_path = CONFIG.src_workspace.relative_to(CONFIG.git_repo.workdir)
src_relative_path = self.src_workspace.relative_to(self.git_repo.workdir)
for todo in self.summarize_todos:
summary = await todo.run()
summary_filename = Path(todo.context.design_filename).with_suffix(".md").name
@ -179,8 +179,8 @@ class Engineer(Role):
else:
await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name)
logger.info(f"--max-auto-summarize-code={CONFIG.max_auto_summarize_code}")
if not tasks or CONFIG.max_auto_summarize_code == 0:
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
if not tasks or self.config.max_auto_summarize_code == 0:
return Message(
content="",
role=self.profile,
@ -190,7 +190,7 @@ class Engineer(Role):
)
# The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited.
# This parameter is used for debugging the workflow.
CONFIG.max_auto_summarize_code -= 1 if CONFIG.max_auto_summarize_code > 0 else 0
self.n_summarize += 1 if self.config.max_auto_summarize_code > self.n_summarize else 0
return Message(
content=json.dumps(tasks), role=self.profile, cause_by=SummarizeCode, send_to=self, sent_from=self
)
@ -203,8 +203,8 @@ class Engineer(Role):
return False, rsp
async def _think(self) -> Action | None:
if not CONFIG.src_workspace:
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
if not self.rc.news:
@ -253,11 +253,11 @@ class Engineer(Role):
async def _new_code_actions(self, bug_fix=False):
# Prepare file repos
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
changed_src_files = src_file_repo.all_files if bug_fix else src_file_repo.changed_files
task_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
task_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO)
changed_task_files = task_file_repo.changed_files
design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_files = Documents()
# Recode caused by upstream changes.
@ -283,7 +283,7 @@ class Engineer(Role):
changed_files.docs[task_filename] = coding_doc
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
# Code directly modified by the user.
dependency = await CONFIG.git_repo.get_dependency()
dependency = await self.git_repo.get_dependency()
for filename in changed_src_files:
if filename in changed_files.docs:
continue
@ -301,7 +301,7 @@ class Engineer(Role):
self.rc.todo = self.code_todos[0]
async def _new_summarize_actions(self):
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
src_files = src_file_repo.all_files
# Generate a SummarizeCode action for each pair of (system_design_doc, task_doc).
summarizations = defaultdict(list)

View file

@ -9,7 +9,6 @@
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
from metagpt.roles.role import Role
from metagpt.utils.common import any_to_name
@ -40,11 +39,11 @@ class ProductManager(Role):
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo and not CONFIG.git_reinit:
if self.git_repo and not self.config.git_reinit:
self._set_state(1)
else:
self._set_state(0)
CONFIG.git_reinit = False
self.context.config.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)

View file

@ -15,10 +15,9 @@
of SummarizeCode.
"""
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.config2 import Config
from metagpt.const import (
MESSAGE_ROUTE_TO_NONE,
TEST_CODES_FILE_REPO,
@ -50,13 +49,17 @@ class QaEngineer(Role):
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
self.test_round = 0
@property
def config(self) -> Config:
return self.context.config
async def _write_test(self, message: Message) -> None:
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
src_file_repo = self.context.git_repo.new_file_repository(self.context.src_workspace)
changed_files = set(src_file_repo.changed_files.keys())
# Unit tests only.
if CONFIG.reqa_file and CONFIG.reqa_file not in changed_files:
changed_files.add(CONFIG.reqa_file)
tests_file_repo = CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
if self.config.reqa_file and self.config.reqa_file not in changed_files:
changed_files.add(self.config.reqa_file)
tests_file_repo = self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
for filename in changed_files:
# write tests
if not filename or "test" in filename:
@ -69,7 +72,7 @@ class QaEngineer(Role):
)
logger.info(f"Writing {test_doc.filename}..")
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
context = await WriteTest(context=context, llm=self.llm).run()
context = await WriteTest(context=context, _context=self.context, llm=self.llm).run()
await tests_file_repo.save(
filename=context.test_doc.filename,
content=context.test_doc.content,
@ -81,8 +84,8 @@ class QaEngineer(Role):
command=["python", context.test_doc.root_relative_path],
code_filename=context.code_doc.filename,
test_filename=context.test_doc.filename,
working_directory=str(CONFIG.git_repo.workdir),
additional_python_paths=[str(CONFIG.src_workspace)],
working_directory=str(self.context.git_repo.workdir),
additional_python_paths=[str(self.context.src_workspace)],
)
self.publish_message(
Message(
@ -98,17 +101,21 @@ class QaEngineer(Role):
async def _run_code(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
src_doc = await CONFIG.git_repo.new_file_repository(CONFIG.src_workspace).get(run_code_context.code_filename)
src_doc = await self.context.git_repo.new_file_repository(self.context.src_workspace).get(
run_code_context.code_filename
)
if not src_doc:
return
test_doc = await CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(run_code_context.test_filename)
test_doc = await self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(
run_code_context.test_filename
)
if not test_doc:
return
run_code_context.code = src_doc.content
run_code_context.test_code = test_doc.content
result = await RunCode(context=run_code_context, llm=self.llm).run()
run_code_context.output_filename = run_code_context.test_filename + ".json"
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
await self.context.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
filename=run_code_context.output_filename,
content=result.model_dump_json(),
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},

View file

@ -32,9 +32,11 @@ from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.llm import LLM, HumanProvider
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import (
@ -148,9 +150,46 @@ class Role(SerializationMixin, is_polymorphic_base=True):
# builtin variables
recovered: bool = False # to tag if a recovered role
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
context: Optional[Context] = Field(default=None, exclude=True)
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
@property
def config(self):
return self.context.config
@property
def git_repo(self):
return self.context.git_repo
@git_repo.setter
def git_repo(self, value):
self.context.git_repo = value
@property
def src_workspace(self):
return self.context.src_workspace
@src_workspace.setter
def src_workspace(self, value):
self.context.src_workspace = value
@property
def prompt_schema(self):
return self.context.config.prompt_schema
@property
def project_name(self):
return self.context.config.project_name
@project_name.setter
def project_name(self, value):
self.context.config.project_name = value
@property
def project_path(self):
return self.context.config.project_path
@model_validator(mode="after")
def check_subscription(self):
if not self.subscription:

View file

@ -15,7 +15,6 @@ import aiofiles
from metagpt.actions import UserRequirement
from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
@ -81,7 +80,7 @@ class Teacher(Role):
async def save(self, content):
"""Save teaching plan"""
filename = Teacher.new_file_name(self.course_title)
pathname = CONFIG.workspace_path / "teaching_plan"
pathname = self.config.workspace.path / "teaching_plan"
pathname.mkdir(exist_ok=True)
pathname = pathname / filename
try:

View file

@ -35,7 +35,6 @@ from pydantic import (
)
from pydantic_core import core_schema
from metagpt.config import CONFIG
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
MESSAGE_ROUTE_FROM,
@ -151,12 +150,6 @@ class Document(BaseModel):
"""
return os.path.join(self.root_path, self.filename)
@property
def full_path(self):
if not CONFIG.git_repo:
return None
return str(CONFIG.git_repo.workdir / self.root_path / self.filename)
def __str__(self):
return self.content

View file

@ -5,7 +5,7 @@ from pathlib import Path
import typer
from metagpt.config import CONFIG
from metagpt.config2 import config
app = typer.Typer(add_completion=False)
@ -44,7 +44,7 @@ def startup(
)
from metagpt.team import Team
CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
if not recover_path:
company = Team()

View file

@ -15,7 +15,6 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH
from metagpt.environment import Environment
from metagpt.logs import logger
@ -79,18 +78,20 @@ class Team(BaseModel):
"""Hire roles to cooperate"""
self.env.add_roles(roles)
@property
def cost_manager(self):
"""Get cost manager"""
return self.env.context.cost_manager
def invest(self, investment: float):
"""Invest company. raise NoMoneyException when exceed max_budget."""
self.investment = investment
CONFIG.max_budget = investment
self.cost_manager.max_budget = investment
logger.info(f"Investment: ${investment}.")
@staticmethod
def _check_balance():
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
raise NoMoneyException(
CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}"
)
def _check_balance(self):
if self.cost_manager.total_cost > self.cost_manager.max_budget:
raise NoMoneyException(self.cost_manager.total_cost, f"Insufficient funds: {self.cost_manager.max_budget}")
def run_project(self, idea, send_to: str = ""):
"""Run a project from publishing user requirement."""

View file

@ -13,7 +13,6 @@ import aiohttp
import requests
from pydantic import BaseModel
from metagpt.config import CONFIG
from metagpt.logs import logger
@ -22,7 +21,7 @@ class MetaGPTText2Image:
"""
:param model_url: Model reset api url
"""
self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL
self.model_url = model_url
async def text_2_image(self, text, size_type="512x512"):
"""Text to image
@ -93,6 +92,4 @@ async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url
"""
if not text:
return ""
if not model_url:
model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type)

View file

@ -10,16 +10,16 @@
import aiohttp
import requests
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
class OpenAIText2Image:
def __init__(self):
def __init__(self, llm: BaseLLM):
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self._llm = LLM()
self.llm = llm
async def text_2_image(self, text, size_type="1024x1024"):
"""Text to image
@ -29,7 +29,7 @@ class OpenAIText2Image:
:return: The image data is returned in Base64 encoding.
"""
try:
result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type)
result = await self.llm.aclient.images.generate(prompt=text, n=1, size=size_type)
except Exception as e:
logger.error(f"An error occurred:{e}")
return ""
@ -57,13 +57,14 @@ class OpenAIText2Image:
# Export
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"):
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024", llm: BaseLLM = None):
"""Text to image
:param text: The text used for image conversion.
:param size_type: One of ['256x256', '512x512', '1024x1024']
:param llm: LLM instance
:return: The image data is returned in Base64 encoding.
"""
if not text:
return ""
return await OpenAIText2Image().text_2_image(text, size_type=size_type)
return await OpenAIText2Image(llm).text_2_image(text, size_type=size_type)

View file

@ -77,7 +77,7 @@ class SDEngine:
return self.payload
def _save(self, imgs, save_name=""):
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
save_dir = CONFIG.path / SD_OUTPUT_FILE_REPO
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)

View file

@ -80,3 +80,20 @@ class CostManager(BaseModel):
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
class TokenCostManager(CostManager):
"""open llm model is self-host, it's free and without cost"""
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")

View file

@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 20:58
@Author : alexanderwu
@File : embedding.py
"""
from langchain_community.embeddings import OpenAIEmbeddings
from metagpt.config2 import config
def get_embedding():
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
return embedding

View file

@ -12,26 +12,25 @@ from datetime import timedelta
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
from metagpt.config import CONFIG
from metagpt.configs.redis_config import RedisConfig
from metagpt.logs import logger
class Redis:
def __init__(self):
def __init__(self, config: RedisConfig = None):
self.config = config
self._client = None
async def _connect(self, force=False):
if self._client and not force:
return True
if not self.is_configured:
return False
try:
self._client = await aioredis.from_url(
f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
username=CONFIG.REDIS_USER,
password=CONFIG.REDIS_PASSWORD,
db=CONFIG.REDIS_DB,
self.config.to_url(),
username=self.config.username,
password=self.config.password,
db=self.config.db,
)
return True
except Exception as e:
@ -62,18 +61,3 @@ class Redis:
return
await self._client.close()
self._client = None
@property
def is_valid(self) -> bool:
return self._client is not None
@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)

View file

@ -8,7 +8,7 @@ from typing import Optional
import aioboto3
import aiofiles
from metagpt.config import CONFIG
from metagpt.config2 import S3Config
from metagpt.const import BASE64_FORMAT
from metagpt.logs import logger
@ -16,13 +16,14 @@ from metagpt.logs import logger
class S3:
"""A class for interacting with Amazon S3 storage."""
def __init__(self):
def __init__(self, config: S3Config):
self.session = aioboto3.Session()
self.config = config
self.auth_config = {
"service_name": "s3",
"aws_access_key_id": CONFIG.S3_ACCESS_KEY,
"aws_secret_access_key": CONFIG.S3_SECRET_KEY,
"endpoint_url": CONFIG.S3_ENDPOINT_URL,
"aws_access_key_id": config.access_key,
"aws_secret_access_key": config.secret_key,
"endpoint_url": config.endpoint,
}
async def upload_file(
@ -139,8 +140,8 @@ class S3:
data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
await file.write(data)
bucket = CONFIG.S3_BUCKET
object_pathname = CONFIG.S3_BUCKET or "system"
bucket = self.config.bucket
object_pathname = self.config.bucket or "system"
object_pathname += f"/{object_name}"
object_pathname = os.path.normpath(object_pathname)
await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
@ -151,20 +152,3 @@ class S3:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
pathname.unlink(missing_ok=True)
return None
@property
def is_valid(self):
return self.is_configured
@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 10:18
@Author : alexanderwu
@File : YamlModel.py
"""
from pathlib import Path
from typing import Dict, Optional
import yaml
from pydantic import BaseModel, model_validator
class YamlModel(BaseModel):
extra_fields: Optional[Dict[str, str]] = None
@classmethod
def read_yaml(cls, file_path: Path) -> Dict:
with open(file_path, "r") as file:
return yaml.safe_load(file)
@classmethod
def model_validate_yaml(cls, file_path: Path) -> "YamlModel":
return cls(**cls.read_yaml(file_path))
def model_dump_yaml(self, file_path: Path) -> None:
with open(file_path, "w") as file:
yaml.dump(self.model_dump(), file)
class YamlModelWithoutDefault(YamlModel):
@model_validator(mode="before")
@classmethod
def check_not_default_config(cls, values):
if any(["YOUR" in v for v in values]):
raise ValueError("Please set your S3 config in config.yaml")
return values

View file

@ -10,29 +10,26 @@
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import Config
from metagpt.learn.text_to_image import text_to_image
@pytest.mark.asyncio
async def test_metagpt_llm():
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY
async def test_metagpt_text_to_image():
config = Config()
assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
data = await text_to_image("Panda emoji", size_type="512x512")
data = await text_to_image("Panda emoji", size_type="512x512", model_url=config.METAGPT_TEXT_TO_IMAGE_MODEL_URL)
assert "base64" in data or "http" in data
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
CONFIG.set_context(new_options)
try:
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
finally:
CONFIG.set_context(old_options)
@pytest.mark.asyncio
async def test_openai_text_to_image():
config = Config.default()
assert config.get_openai_llm()
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
assert "base64" in data or "http" in data
if __name__ == "__main__":

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMType
from metagpt.llm import LLM
from metagpt.memory.brain_memory import BrainMemory
from metagpt.schema import Message
@ -46,7 +46,7 @@ def test_extract_info(input, tag, val):
@pytest.mark.asyncio
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
@pytest.mark.parametrize("llm", [LLM(provider=LLMType.OPENAI), LLM(provider=LLMType.METAGPT)])
async def test_memory_llm(llm):
memory = BrainMemory()
for i in range(500):

View file

@ -3,12 +3,8 @@
# @Desc :
from metagpt.config import CONFIG
from metagpt.provider.azure_openai_api import AzureOpenAILLM
CONFIG.OPENAI_API_VERSION = "xx"
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
from metagpt.context import context
def test_azure_openai_api():
_ = AzureOpenAILLM()
_ = context.llm("azure")

View file

@ -8,6 +8,7 @@
import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
@ -28,6 +29,9 @@ resp_content = default_chat_resp["choices"][0]["message"]["content"]
class MockBaseLLM(BaseLLM):
def __init__(self, config: LLMConfig = None):
pass
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
@ -102,5 +106,5 @@ async def test_async_base_llm():
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_llm.aask_code([prompt_msg])
assert resp == resp_content
# resp = await base_llm.aask_code([prompt_msg])
# assert resp == resp_content

View file

@ -5,10 +5,10 @@
@Author : mashenquan
@File : test_metagpt_api.py
"""
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMType
from metagpt.llm import LLM
def test_llm():
llm = LLM(provider=LLMProviderEnum.METAGPT)
llm = LLM(provider=LLMType.METAGPT)
assert llm

View file

@ -2,18 +2,18 @@ from unittest.mock import Mock
import pytest
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAILLM
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.schema import UserMessage
CONFIG.openai_proxy = None
@pytest.mark.asyncio
async def test_aask_code():
llm = OpenAILLM()
llm = LLM(name="gpt3t")
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
logger.info(rsp)
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
@ -21,7 +21,7 @@ async def test_aask_code():
@pytest.mark.asyncio
async def test_aask_code_str():
llm = OpenAILLM()
llm = LLM(name="gpt3t")
msg = "Write a python hello world code."
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -30,8 +30,8 @@ async def test_aask_code_str():
@pytest.mark.asyncio
async def test_aask_code_Message():
llm = OpenAILLM()
async def test_aask_code_message():
llm = LLM(name="gpt3t")
msg = UserMessage("Write a python hello world code.")
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp

View file

@ -28,6 +28,6 @@ def load_existing_repo(path):
def test_repo_set_load():
repo_path = CONFIG.workspace_path / "test_repo"
repo_path = CONFIG.path / "test_repo"
set_existing_repo(repo_path)
load_existing_repo(repo_path)

View file

@ -15,7 +15,6 @@ import pytest
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
@ -119,8 +118,6 @@ def test_document():
assert doc.filename == meta_doc.filename
assert meta_doc.content == ""
assert doc.full_path == str(CONFIG.git_repo.workdir / doc.root_path / doc.filename)
@pytest.mark.asyncio
async def test_message_queue():

View file

@ -32,7 +32,7 @@ async def test_azure_tts():
Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.”
</mstts:express-as>
"""
path = CONFIG.workspace_path / "tts"
path = CONFIG.path / "tts"
path.mkdir(exist_ok=True, parents=True)
filename = path / "girl.wav"
filename.unlink(missing_ok=True)

View file

@ -22,5 +22,5 @@ def test_sd_engine_generate_prompt():
async def test_sd_engine_run_t2i():
sd_engine = SDEngine()
await sd_engine.run_t2i(prompts=["test"])
img_path = CONFIG.workspace_path / "resources" / "SD_Output" / "output_0.png"
img_path = CONFIG.path / "resources" / "SD_Output" / "output_0.png"
assert os.path.exists(img_path)

View file

@ -8,38 +8,19 @@
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import Config
from metagpt.utils.redis import Redis
@pytest.mark.asyncio
async def test_redis():
# Prerequisites
assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
# assert CONFIG.REDIS_USER
assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD"
assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based"
redis = Config.default().redis
conn = Redis()
assert not conn.is_valid
conn = Redis(redis)
await conn.set("test", "test", timeout_sec=0)
assert await conn.get("test") == b"test"
await conn.close()
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["REDIS_HOST"] = "YOUR_REDIS_HOST"
CONFIG.set_context(new_options)
try:
conn = Redis()
await conn.set("test", "test", timeout_sec=0)
assert not await conn.get("test") == b"test"
await conn.close()
finally:
CONFIG.set_context(old_options)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -11,30 +11,25 @@ from pathlib import Path
import aiofiles
import pytest
from metagpt.config import CONFIG
from metagpt.config2 import Config
from metagpt.utils.s3 import S3
@pytest.mark.asyncio
async def test_s3():
# Prerequisites
assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
# assert CONFIG.S3_SECURE: true # true/false
assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
conn = S3()
assert conn.is_valid
s3 = Config.default().s3
assert s3
conn = S3(s3)
object_name = "unittest.bak"
await conn.upload_file(bucket=CONFIG.S3_BUCKET, local_path=__file__, object_name=object_name)
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
pathname.unlink(missing_ok=True)
await conn.download_file(bucket=CONFIG.S3_BUCKET, object_name=object_name, local_path=str(pathname))
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
assert pathname.exists()
url = await conn.get_object_url(bucket=CONFIG.S3_BUCKET, object_name=object_name)
url = await conn.get_object_url(bucket=s3.bucket, object_name=object_name)
assert url
bin_data = await conn.get_object(bucket=CONFIG.S3_BUCKET, object_name=object_name)
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
assert bin_data
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
data = await reader.read()
@ -42,17 +37,13 @@ async def test_s3():
assert "http" in res
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
CONFIG.set_context(new_options)
s3.access_key = "ABC"
try:
conn = S3()
assert not conn.is_valid
conn = S3(s3)
res = await conn.cache("ABC", ".bak", "script")
assert not res
finally:
CONFIG.set_context(old_options)
except Exception:
pass
if __name__ == "__main__":