mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
add context and config2
This commit is contained in:
parent
42bb40a0f6
commit
e5d11a046c
76 changed files with 922 additions and 495 deletions
4
config/config2.yaml
Normal file
4
config/config2.yaml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
llm:
|
||||
gpt3t:
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-3.5-turbo-1106"
|
||||
|
|
@ -6,7 +6,7 @@ Author: garylin2099
|
|||
import re
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -48,8 +48,8 @@ class CreateAgent(Action):
|
|||
pattern = r"```python(.*)```"
|
||||
match = re.search(pattern, rsp, re.DOTALL)
|
||||
code_text = match.group(1) if match else ""
|
||||
CONFIG.workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
new_file = CONFIG.workspace_path / "agent_created_agent.py"
|
||||
config.workspace.path.mkdir(parents=True, exist_ok=True)
|
||||
new_file = config.workspace.path / "agent_created_agent.py"
|
||||
new_file.write_text(code_text)
|
||||
return code_text
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -16,7 +16,8 @@ from metagpt.roles import Sales
|
|||
|
||||
|
||||
def get_store():
|
||||
embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url)
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import Optional, Union
|
|||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.context import Context
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import (
|
||||
|
|
@ -33,14 +34,41 @@ class Action(SerializationMixin, is_polymorphic_base=True):
|
|||
prefix: str = "" # aask*时会加上prefix,作为system_message
|
||||
desc: str = "" # for skill manager
|
||||
node: ActionNode = Field(default=None, exclude=True)
|
||||
_context: Optional[Context] = Field(default=None, exclude=True)
|
||||
|
||||
@property
|
||||
def git_repo(self):
|
||||
return self._context.git_repo
|
||||
|
||||
@property
|
||||
def src_workspace(self):
|
||||
return self._context.src_workspace
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
return self._context.config.prompt_schema
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return self._context.config.project_name
|
||||
|
||||
@project_name.setter
|
||||
def project_name(self, value):
|
||||
self._context.config.project_name = value
|
||||
|
||||
@property
|
||||
def project_path(self):
|
||||
return self._context.config.project_path
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_name_if_empty(cls, values):
|
||||
if "name" not in values or not values["name"]:
|
||||
values["name"] = cls.__name__
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _init_with_instruction(cls, values):
|
||||
if "instruction" in values:
|
||||
name = values["name"]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
|||
from pydantic import BaseModel, create_model, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
|
|
@ -262,7 +261,7 @@ class ActionNode:
|
|||
output_data_mapping: dict,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
schema="markdown", # compatible to original format
|
||||
timeout=CONFIG.timeout,
|
||||
timeout=None,
|
||||
) -> (str, BaseModel):
|
||||
"""Use ActionOutput to wrap the output of aask"""
|
||||
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
|
||||
|
|
@ -294,7 +293,7 @@ class ActionNode:
|
|||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
|
||||
async def simple_fill(self, schema, mode, timeout=None, exclude=None):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
|
|
@ -309,7 +308,7 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=None, exclude=[]):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
|
|
|
|||
|
|
@ -9,12 +9,13 @@
|
|||
2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
|
||||
"""
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -49,8 +50,8 @@ Now you should start rewriting the code:
|
|||
|
||||
|
||||
class DebugError(Action):
|
||||
name: str = "DebugError"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
_context: Optional[Context] = None
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await FileRepository.get_file(
|
||||
|
|
@ -66,7 +67,7 @@ class DebugError(Action):
|
|||
|
||||
logger.info(f"Debug and rewrite {self.context.test_filename}")
|
||||
code_doc = await FileRepository.get_file(
|
||||
filename=self.context.code_filename, relative_path=CONFIG.src_workspace
|
||||
filename=self.context.code_filename, relative_path=self._context.src_workspace
|
||||
)
|
||||
if not code_doc:
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from typing import Optional
|
|||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.design_api_an import DESIGN_API_NODE
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
DATA_API_DESIGN_FILE_REPO,
|
||||
PRDS_FILE_REPO,
|
||||
|
|
@ -46,13 +45,13 @@ class WriteDesign(Action):
|
|||
"clearly and in detail."
|
||||
)
|
||||
|
||||
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
|
||||
async def run(self, with_messages: Message, schema: str = None):
|
||||
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
|
||||
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
changed_prds = prds_file_repo.changed_files
|
||||
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
|
||||
# changes.
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
changed_system_designs = system_design_file_repo.changed_files
|
||||
|
||||
# For those PRDs and design documents that have undergone changes, regenerate the design content.
|
||||
|
|
@ -76,11 +75,11 @@ class WriteDesign(Action):
|
|||
# leaving room for global optimization in subsequent steps.
|
||||
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
|
||||
|
||||
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
|
||||
async def _new_system_design(self, context, schema=None):
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
|
||||
async def _merge(self, prd_doc, system_design_doc, schema=None):
|
||||
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
system_design_doc.content = node.instruct_content.model_dump_json()
|
||||
|
|
@ -106,23 +105,21 @@ class WriteDesign(Action):
|
|||
await self._save_pdf(doc)
|
||||
return doc
|
||||
|
||||
@staticmethod
|
||||
async def _save_data_api_design(design_doc):
|
||||
async def _save_data_api_design(self, design_doc):
|
||||
m = json.loads(design_doc.content)
|
||||
data_api_design = m.get("Data structures and interfaces")
|
||||
if not data_api_design:
|
||||
return
|
||||
pathname = CONFIG.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
|
||||
pathname = self.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
|
||||
await WriteDesign._save_mermaid_file(data_api_design, pathname)
|
||||
logger.info(f"Save class view to {str(pathname)}")
|
||||
|
||||
@staticmethod
|
||||
async def _save_seq_flow(design_doc):
|
||||
async def _save_seq_flow(self, design_doc):
|
||||
m = json.loads(design_doc.content)
|
||||
seq_flow = m.get("Program call flow")
|
||||
if not seq_flow:
|
||||
return
|
||||
pathname = CONFIG.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
|
||||
pathname = self.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
|
||||
await WriteDesign._save_mermaid_file(seq_flow, pathname)
|
||||
logger.info(f"Saving sequence flow to {str(pathname)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from pathlib import Path
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -25,18 +24,22 @@ class PrepareDocuments(Action):
|
|||
name: str = "PrepareDocuments"
|
||||
context: Optional[str] = None
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._context.config
|
||||
|
||||
def _init_repo(self):
|
||||
"""Initialize the Git environment."""
|
||||
if not CONFIG.project_path:
|
||||
name = CONFIG.project_name or FileRepository.new_filename()
|
||||
path = Path(CONFIG.workspace_path) / name
|
||||
if not self.config.project_path:
|
||||
name = self.config.project_name or FileRepository.new_filename()
|
||||
path = Path(self.config.workspace.path) / name
|
||||
else:
|
||||
path = Path(CONFIG.project_path)
|
||||
if path.exists() and not CONFIG.inc:
|
||||
path = Path(self.config.project_path)
|
||||
if path.exists() and not self.config.inc:
|
||||
shutil.rmtree(path)
|
||||
CONFIG.project_path = path
|
||||
CONFIG.project_name = path.name
|
||||
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
self.config.project_path = path
|
||||
self.config.project_name = path.name
|
||||
self._context.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
"""Create and initialize the workspace folder, initialize the Git environment."""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from typing import Optional
|
|||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management_an import PM_NODE
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
PACKAGE_REQUIREMENTS_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
|
|
@ -40,11 +39,15 @@ class WriteTasks(Action):
|
|||
name: str = "CreateTasks"
|
||||
context: Optional[str] = None
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
return self._context.config.prompt_schema
|
||||
|
||||
async def run(self, with_messages, schema=None):
|
||||
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
changed_system_designs = system_design_file_repo.changed_files
|
||||
|
||||
tasks_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
|
||||
tasks_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO)
|
||||
changed_tasks = tasks_file_repo.changed_files
|
||||
change_files = Documents()
|
||||
# Rewrite the system designs that have undergone changes based on the git head diff under
|
||||
|
|
@ -87,21 +90,20 @@ class WriteTasks(Action):
|
|||
await self._save_pdf(task_doc=task_doc)
|
||||
return task_doc
|
||||
|
||||
async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
|
||||
node = await PM_NODE.fill(context, self.llm, schema)
|
||||
async def _run_new_tasks(self, context):
|
||||
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
|
||||
async def _merge(self, system_design_doc, task_doc) -> Document:
|
||||
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
|
||||
node = await PM_NODE.fill(context, self.llm, schema)
|
||||
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
task_doc.content = node.instruct_content.model_dump_json()
|
||||
return task_doc
|
||||
|
||||
@staticmethod
|
||||
async def _update_requirements(doc):
|
||||
async def _update_requirements(self, doc):
|
||||
m = json.loads(doc.content)
|
||||
packages = set(m.get("Required Python third-party packages", set()))
|
||||
file_repo = CONFIG.git_repo.new_file_repository()
|
||||
file_repo = self.git_repo.new_file_repository()
|
||||
requirement_doc = await file_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
if not requirement_doc:
|
||||
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import re
|
|||
from pathlib import Path
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
|
|
@ -21,8 +20,8 @@ class RebuildClassView(Action):
|
|||
def __init__(self, name="", context=None, llm=None):
|
||||
super().__init__(name=name, context=context, llm=llm)
|
||||
|
||||
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
|
||||
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
|
||||
async def run(self, with_messages=None):
|
||||
graph_repo_pathname = self.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
repo_parser = RepoParser(base_directory=self.context)
|
||||
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
|
||||
|
|
@ -57,7 +56,7 @@ class RebuildClassView(Action):
|
|||
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
|
||||
|
||||
async def _save(self, graph_db):
|
||||
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
|
||||
class_view_file_repo = self.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
|
||||
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
|
||||
all_class_view = []
|
||||
for spo in dataset:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from typing import Tuple
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -89,13 +88,12 @@ class RunCode(Action):
|
|||
return "", str(e)
|
||||
return namespace.get("result", ""), ""
|
||||
|
||||
@classmethod
|
||||
async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
|
||||
async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
|
||||
working_directory = str(working_directory)
|
||||
additional_python_paths = [str(path) for path in additional_python_paths]
|
||||
|
||||
# Copy the current environment variables
|
||||
env = CONFIG.new_environ()
|
||||
env = self._context.new_environ()
|
||||
|
||||
# Modify the PYTHONPATH environment variable
|
||||
additional_python_paths = [working_directory] + additional_python_paths
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import pydantic
|
|||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.config import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -103,12 +103,11 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
"""
|
||||
|
||||
|
||||
# TOTEST
|
||||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
engine: Optional[SearchEngineType] = None
|
||||
search_func: Optional[Any] = None
|
||||
search_engine: SearchEngine = None
|
||||
result: str = ""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from pydantic import Field
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
|
|
@ -105,7 +104,7 @@ class SummarizeCode(Action):
|
|||
design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
task_pathname = Path(self.context.task_filename)
|
||||
task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
src_file_repo = self.git_repo.new_file_repository(relative_path=self._context.src_workspace)
|
||||
code_blocks = []
|
||||
for filename in self.context.codes_filenames:
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from pydantic import Field
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
|
|
@ -114,7 +113,12 @@ class WriteCode(Action):
|
|||
if bug_feedback:
|
||||
code_context = coding_context.code_doc.content
|
||||
else:
|
||||
code_context = await self.get_codes(coding_context.task_doc, exclude=self.context.filename)
|
||||
code_context = await self.get_codes(
|
||||
coding_context.task_doc,
|
||||
exclude=self.context.filename,
|
||||
git_repo=self.git_repo,
|
||||
src_workspace=self._context.src_workspace,
|
||||
)
|
||||
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
design=coding_context.design_doc.content if coding_context.design_doc else "",
|
||||
|
|
@ -129,13 +133,13 @@ class WriteCode(Action):
|
|||
code = await self.write_code(prompt)
|
||||
if not coding_context.code_doc:
|
||||
# avoid root_path pydantic ValidationError if use WriteCode alone
|
||||
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
|
||||
root_path = self._context.src_workspace if self._context.src_workspace else ""
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
|
||||
coding_context.code_doc.content = code
|
||||
return coding_context
|
||||
|
||||
@staticmethod
|
||||
async def get_codes(task_doc, exclude) -> str:
|
||||
async def get_codes(task_doc, exclude, git_repo, src_workspace) -> str:
|
||||
if not task_doc:
|
||||
return ""
|
||||
if not task_doc.content:
|
||||
|
|
@ -143,7 +147,7 @@ class WriteCode(Action):
|
|||
m = json.loads(task_doc.content)
|
||||
code_filenames = m.get("Task list", [])
|
||||
codes = []
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
src_file_repo = git_repo.new_file_repository(relative_path=src_workspace)
|
||||
for filename in code_filenames:
|
||||
if filename == exclude:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -137,11 +136,16 @@ class WriteCodeReview(Action):
|
|||
|
||||
async def run(self, *args, **kwargs) -> CodingContext:
|
||||
iterative_code = self.context.code_doc.content
|
||||
k = CONFIG.code_review_k_times or 1
|
||||
k = self._context.config.code_review_k_times or 1
|
||||
for i in range(k):
|
||||
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
|
||||
task_content = self.context.task_doc.content if self.context.task_doc else ""
|
||||
code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename)
|
||||
code_context = await WriteCode.get_codes(
|
||||
self.context.task_doc,
|
||||
exclude=self.context.filename,
|
||||
git_repo=self._context.git_repo,
|
||||
src_workspace=self.src_workspace,
|
||||
)
|
||||
context = "\n".join(
|
||||
[
|
||||
"## System Design\n" + str(self.context.design_doc) + "\n",
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from metagpt.actions.write_prd_an import (
|
|||
WP_ISSUE_TYPE_NODE,
|
||||
WRITE_PRD_NODE,
|
||||
)
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
COMPETITIVE_ANALYSIS_FILE_REPO,
|
||||
|
|
@ -65,10 +64,10 @@ class WritePRD(Action):
|
|||
name: str = "WritePRD"
|
||||
content: Optional[str] = None
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
# related to the PRD. If they are related, rewrite the PRD.
|
||||
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
|
||||
docs_file_repo = self.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
|
||||
requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME)
|
||||
if requirement_doc and await self._is_bugfix(requirement_doc.content):
|
||||
await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
|
||||
|
|
@ -85,7 +84,7 @@ class WritePRD(Action):
|
|||
else:
|
||||
await docs_file_repo.delete(filename=BUGFIX_FILENAME)
|
||||
|
||||
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
prd_docs = await prds_file_repo.get_all()
|
||||
change_files = Documents()
|
||||
for prd_doc in prd_docs:
|
||||
|
|
@ -109,7 +108,7 @@ class WritePRD(Action):
|
|||
# optimization in subsequent steps.
|
||||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
|
||||
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
|
||||
async def _run_new_requirement(self, requirements) -> ActionOutput:
|
||||
# sas = SearchAndSummarize()
|
||||
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
|
||||
# rsp = ""
|
||||
|
|
@ -117,7 +116,7 @@ class WritePRD(Action):
|
|||
# if sas.result:
|
||||
# logger.info(sas.result)
|
||||
# logger.info(rsp)
|
||||
project_name = CONFIG.project_name if CONFIG.project_name else ""
|
||||
project_name = self.project_name
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
|
||||
|
|
@ -129,11 +128,11 @@ class WritePRD(Action):
|
|||
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
|
||||
return node.get("is_relative") == "YES"
|
||||
|
||||
async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
|
||||
if not CONFIG.project_name:
|
||||
CONFIG.project_name = Path(CONFIG.project_path).name
|
||||
async def _merge(self, new_requirement_doc, prd_doc) -> Document:
|
||||
if not self.project_name:
|
||||
self.project_name = Path(self.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
|
||||
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
|
||||
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
prd_doc.content = node.instruct_content.model_dump_json()
|
||||
await self._rename_workspace(node)
|
||||
return prd_doc
|
||||
|
|
@ -157,15 +156,12 @@ class WritePRD(Action):
|
|||
await self._save_pdf(new_prd_doc)
|
||||
return new_prd_doc
|
||||
|
||||
@staticmethod
|
||||
async def _save_competitive_analysis(prd_doc):
|
||||
async def _save_competitive_analysis(self, prd_doc):
|
||||
m = json.loads(prd_doc.content)
|
||||
quadrant_chart = m.get("Competitive Quadrant Chart")
|
||||
if not quadrant_chart:
|
||||
return
|
||||
pathname = (
|
||||
CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
|
||||
)
|
||||
pathname = self.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
|
||||
if not pathname.parent.exists():
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(quadrant_chart, pathname)
|
||||
|
|
@ -174,20 +170,19 @@ class WritePRD(Action):
|
|||
async def _save_pdf(prd_doc):
|
||||
await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
|
||||
|
||||
@staticmethod
|
||||
async def _rename_workspace(prd):
|
||||
if not CONFIG.project_name:
|
||||
async def _rename_workspace(self, prd):
|
||||
if not self.project_name:
|
||||
if isinstance(prd, (ActionOutput, ActionNode)):
|
||||
ws_name = prd.instruct_content.model_dump()["Project Name"]
|
||||
else:
|
||||
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
|
||||
if ws_name:
|
||||
CONFIG.project_name = ws_name
|
||||
CONFIG.git_repo.rename_root(CONFIG.project_name)
|
||||
self.project_name = ws_name
|
||||
self.git_repo.rename_root(self.project_name)
|
||||
|
||||
async def _is_bugfix(self, context) -> bool:
|
||||
src_workspace_path = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
|
||||
code_files = CONFIG.git_repo.get_files(relative_path=src_workspace_path)
|
||||
src_workspace_path = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
code_files = self.git_repo.get_files(relative_path=src_workspace_path)
|
||||
if not code_files:
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ class WriteTeachingPlanPart(Action):
|
|||
if "{" not in value:
|
||||
return value
|
||||
|
||||
# FIXME: 从Context中获取参数
|
||||
merged_opts = CONFIG.options or {}
|
||||
try:
|
||||
return value.format(**merged_opts)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document, TestingContext
|
||||
|
|
@ -64,7 +63,7 @@ class WriteTest(Action):
|
|||
code_to_test=self.context.code_doc.content,
|
||||
test_file_name=self.context.test_doc.filename,
|
||||
source_file_path=self.context.code_doc.root_relative_path,
|
||||
workspace=CONFIG.git_repo.workdir,
|
||||
workspace=self.git_repo.workdir,
|
||||
)
|
||||
self.context.test_doc.content = await self.write_code(prompt)
|
||||
return self.context
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ import json
|
|||
import os
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
|
|
@ -38,19 +38,6 @@ class NotConfiguredException(Exception):
|
|||
super().__init__(self.message)
|
||||
|
||||
|
||||
class LLMProviderEnum(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
METAGPT = "metagpt"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""
|
||||
Regular usage method:
|
||||
|
|
@ -81,27 +68,25 @@ class Config(metaclass=Singleton):
|
|||
global_options.update(OPTIONS.get())
|
||||
logger.debug("Config loading done.")
|
||||
|
||||
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
|
||||
def get_default_llm_provider_enum(self) -> LLMType:
|
||||
"""Get first valid LLM provider enum"""
|
||||
mappings = {
|
||||
LLMProviderEnum.OPENAI: bool(
|
||||
LLMType.OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
|
||||
),
|
||||
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
|
||||
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
|
||||
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
|
||||
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
|
||||
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
|
||||
LLMProviderEnum.METAGPT: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
|
||||
),
|
||||
LLMProviderEnum.AZURE_OPENAI: bool(
|
||||
LLMType.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
|
||||
LLMType.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
|
||||
LLMType.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
|
||||
LLMType.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
|
||||
LLMType.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
|
||||
LLMType.METAGPT: bool(self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"),
|
||||
LLMType.AZURE_OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY)
|
||||
and self.OPENAI_API_TYPE == "azure"
|
||||
and self.DEPLOYMENT_NAME
|
||||
and self.OPENAI_API_VERSION
|
||||
),
|
||||
LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
|
||||
LLMType.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
|
||||
}
|
||||
provider = None
|
||||
for k, v in mappings.items():
|
||||
|
|
@ -109,7 +94,7 @@ class Config(metaclass=Singleton):
|
|||
provider = k
|
||||
break
|
||||
|
||||
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
if provider is LLMType.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
model_name = self.get_model_name(provider=provider)
|
||||
if model_name:
|
||||
|
|
@ -122,8 +107,8 @@ class Config(metaclass=Singleton):
|
|||
def get_model_name(self, provider=None) -> str:
|
||||
provider = provider or self.get_default_llm_provider_enum()
|
||||
model_mappings = {
|
||||
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
|
||||
LLMType.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMType.AZURE_OPENAI: self.DEPLOYMENT_NAME,
|
||||
}
|
||||
return model_mappings.get(provider, "")
|
||||
|
||||
|
|
@ -166,6 +151,7 @@ class Config(metaclass=Singleton):
|
|||
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
|
||||
|
||||
self.claude_api_key = self._get("ANTHROPIC_API_KEY")
|
||||
|
||||
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
|
||||
self.serper_api_key = self._get("SERPER_API_KEY")
|
||||
self.google_api_key = self._get("GOOGLE_API_KEY")
|
||||
|
|
@ -200,7 +186,7 @@ class Config(metaclass=Singleton):
|
|||
self.workspace_path = self.workspace_path / workspace_uid
|
||||
self._ensure_workspace_exists()
|
||||
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
|
||||
self.timeout = int(self._get("TIMEOUT", 3))
|
||||
self.timeout = int(self._get("TIMEOUT", 60))
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
|
|
|||
124
metagpt/config2.py
Normal file
124
metagpt/config2.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 01:25
|
||||
@Author : alexanderwu
|
||||
@File : llm_factory.py
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
from metagpt.configs.s3_config import S3Config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.configs.workspace_config import WorkspaceConfig
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class CLIParams(BaseModel):
|
||||
project_path: str = ""
|
||||
project_name: str = ""
|
||||
inc: bool = False
|
||||
reqa_file: str = ""
|
||||
max_auto_summarize_code: int = 0
|
||||
git_reinit: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_project_path(self):
|
||||
if self.project_path:
|
||||
self.inc = True
|
||||
self.project_name = self.project_name or Path(self.project_path).name
|
||||
|
||||
|
||||
class Config(CLIParams, YamlModel):
|
||||
# Key Parameters
|
||||
llm: Dict[str, LLMConfig] = Field(default_factory=Dict)
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
# Tool Parameters
|
||||
search: Dict[str, SearchConfig] = {}
|
||||
browser: Dict[str, BrowserConfig] = {"default": BrowserConfig()}
|
||||
mermaid: Dict[str, MermaidConfig] = {"default": MermaidConfig()}
|
||||
|
||||
# Storage Parameters
|
||||
s3: Optional[S3Config] = None
|
||||
redis: Optional[RedisConfig] = None
|
||||
|
||||
# Misc Parameters
|
||||
repair_llm_output: bool = False
|
||||
prompt_schema: Literal["json", "markdown", "raw"] = "json"
|
||||
workspace: WorkspaceConfig = WorkspaceConfig()
|
||||
enable_longterm_memory: bool = False
|
||||
code_review_k_times: int = 2
|
||||
|
||||
# Will be removed in the future
|
||||
llm_for_researcher_summary: str = "gpt3"
|
||||
llm_for_researcher_report: str = "gpt3"
|
||||
METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = ""
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
"""Load default config
|
||||
- Priority: env < default_config_paths
|
||||
- Inside default_config_paths, the latter one overwrites the former one
|
||||
"""
|
||||
default_config_paths: List[Path] = [
|
||||
METAGPT_ROOT / "config/config2.yaml",
|
||||
Path.home() / ".metagpt/config2.yaml",
|
||||
]
|
||||
|
||||
dicts = [dict(os.environ)]
|
||||
dicts += [Config.read_yaml(path) for path in default_config_paths]
|
||||
final = merge_dict(dicts)
|
||||
return Config(**final)
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
||||
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
|
||||
if project_path:
|
||||
inc = True
|
||||
project_name = project_name or Path(project_path).name
|
||||
self.project_path = project_path
|
||||
self.project_name = project_name
|
||||
self.inc = inc
|
||||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
def get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
|
||||
"""Get LLM instance by name"""
|
||||
if name is None:
|
||||
# Use the first LLM as default
|
||||
name = list(self.llm.keys())[0]
|
||||
if name not in self.llm:
|
||||
raise ValueError(f"LLM {name} not found in config")
|
||||
return self.llm[name]
|
||||
|
||||
def get_openai_llm(self, name: Optional[str] = None) -> LLMConfig:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
if name is None:
|
||||
# Use the first OpenAI LLM as default
|
||||
name = [k for k, v in self.llm.items() if v.api_type == LLMType.OPENAI][0]
|
||||
if name not in self.llm:
|
||||
raise ValueError(f"OpenAI LLM {name} not found in config")
|
||||
return self.llm[name]
|
||||
|
||||
|
||||
def merge_dict(dicts: Iterable[Dict]) -> Dict:
|
||||
"""Merge multiple dicts into one, with the latter dict overwriting the former"""
|
||||
result = {}
|
||||
for dictionary in dicts:
|
||||
result.update(dictionary)
|
||||
return result
|
||||
|
||||
|
||||
config = Config.default()
|
||||
7
metagpt/configs/__init__.py
Normal file
7
metagpt/configs/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 16:33
|
||||
@Author : alexanderwu
|
||||
@File : __init__.py
|
||||
"""
|
||||
20
metagpt/configs/browser_config.py
Normal file
20
metagpt/configs/browser_config.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:06
|
||||
@Author : alexanderwu
|
||||
@File : browser_config.py
|
||||
"""
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class BrowserConfig(YamlModel):
|
||||
"""Config for Browser"""
|
||||
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
|
||||
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
|
||||
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
|
||||
path: str = ""
|
||||
74
metagpt/configs/llm_config.py
Normal file
74
metagpt/configs/llm_config.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 16:33
|
||||
@Author : alexanderwu
|
||||
@File : llm_config.py
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class LLMType(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
METAGPT = "metagpt"
|
||||
AZURE_OPENAI = "azure"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class LLMConfig(YamlModel):
|
||||
"""Config for LLM
|
||||
|
||||
OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions.py#L681
|
||||
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
api_type: LLMType = LLMType.OPENAI
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_version: Optional[str] = None
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
app_id: Optional[str] = None
|
||||
api_secret: Optional[str] = None
|
||||
domain: Optional[str] = None
|
||||
|
||||
# For Chat Completion
|
||||
max_token: int = 4096
|
||||
temperature: float = 0.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = 0
|
||||
repetition_penalty: float = 1.0
|
||||
stop: Optional[str] = None
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
best_of: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
stream: bool = False
|
||||
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
|
||||
top_logprobs: Optional[int] = None
|
||||
timeout: int = 60
|
||||
|
||||
# For Network
|
||||
proxy: Optional[str] = None
|
||||
|
||||
# Cost Control
|
||||
calc_usage: bool = True
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
if v in ["", None, "YOUR_API_KEY"]:
|
||||
raise ValueError("Please set your API key in config.yaml")
|
||||
return v
|
||||
18
metagpt/configs/mermaid_config.py
Normal file
18
metagpt/configs/mermaid_config.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:07
|
||||
@Author : alexanderwu
|
||||
@File : mermaid_config.py
|
||||
"""
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class MermaidConfig(YamlModel):
|
||||
"""Config for Mermaid"""
|
||||
|
||||
engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs"
|
||||
path: str = ""
|
||||
puppeteer_config: str = "" # Only for nodejs engine
|
||||
26
metagpt/configs/redis_config.py
Normal file
26
metagpt/configs/redis_config.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:06
|
||||
@Author : alexanderwu
|
||||
@File : redis_config.py
|
||||
"""
|
||||
from metagpt.utils.yaml_model import YamlModelWithoutDefault
|
||||
|
||||
|
||||
class RedisConfig(YamlModelWithoutDefault):
|
||||
host: str
|
||||
port: int
|
||||
username: str = ""
|
||||
password: str
|
||||
db: str
|
||||
|
||||
def to_url(self):
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
def to_kwargs(self):
|
||||
return {
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"db": self.db,
|
||||
}
|
||||
15
metagpt/configs/s3_config.py
Normal file
15
metagpt/configs/s3_config.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:07
|
||||
@Author : alexanderwu
|
||||
@File : s3_config.py
|
||||
"""
|
||||
from metagpt.utils.yaml_model import YamlModelWithoutDefault
|
||||
|
||||
|
||||
class S3Config(YamlModelWithoutDefault):
|
||||
access_key: str
|
||||
secret_key: str
|
||||
endpoint: str
|
||||
bucket: str
|
||||
17
metagpt/configs/search_config.py
Normal file
17
metagpt/configs/search_config.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:06
|
||||
@Author : alexanderwu
|
||||
@File : search_config.py
|
||||
"""
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class SearchConfig(YamlModel):
|
||||
"""Config for Search"""
|
||||
|
||||
api_key: str
|
||||
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
|
||||
cse_id: str = "" # for google
|
||||
38
metagpt/configs/workspace_config.py
Normal file
38
metagpt/configs/workspace_config.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:09
|
||||
@Author : alexanderwu
|
||||
@File : workspace_config.py
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import field_validator, model_validator
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class WorkspaceConfig(YamlModel):
|
||||
path: Path = DEFAULT_WORKSPACE_ROOT
|
||||
use_uid: bool = False
|
||||
uid: str = ""
|
||||
|
||||
@field_validator("path")
|
||||
@classmethod
|
||||
def check_workspace_path(cls, v):
|
||||
if isinstance(v, str):
|
||||
v = Path(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_uid_and_update_path(self):
|
||||
if self.use_uid and not self.uid:
|
||||
self.uid = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
|
||||
self.path = self.path / self.uid
|
||||
|
||||
# Create workspace path if not exists
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
return self
|
||||
55
metagpt/context.py
Normal file
55
metagpt/context.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 16:32
|
||||
@Author : alexanderwu
|
||||
@File : context.py
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import OPTIONS
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import get_llm
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
kwargs: Dict = {}
|
||||
config: Config = Config.default()
|
||||
git_repo: Optional[GitRepository] = None
|
||||
src_workspace: Optional[Path] = None
|
||||
cost_manager: CostManager = CostManager()
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
"""Return all key-values"""
|
||||
return OPTIONS.get()
|
||||
|
||||
def new_environ(self):
|
||||
"""Return a new os.environ object"""
|
||||
env = os.environ.copy()
|
||||
i = self.options
|
||||
env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
def llm(self, name: Optional[str] = None) -> BaseLLM:
|
||||
"""Return a LLM instance"""
|
||||
llm = get_llm(self.config.get_llm_config(name))
|
||||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self.cost_manager
|
||||
return llm
|
||||
|
||||
|
||||
# Global context
|
||||
context = Context()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(context.model_dump_json(indent=4))
|
||||
print(context.config.get_openai_llm())
|
||||
|
|
@ -9,14 +9,13 @@ import asyncio
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.document import IndexableDocument
|
||||
from metagpt.document_store.base_store import LocalStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
class FaissStore(LocalStore):
|
||||
|
|
@ -25,9 +24,7 @@ class FaissStore(LocalStore):
|
|||
):
|
||||
self.meta_col = meta_col
|
||||
self.content_col = content_col
|
||||
self.embedding = embedding or OpenAIEmbeddings(
|
||||
openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
|
||||
)
|
||||
self.embedding = embedding or get_embedding()
|
||||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from typing import Iterable, Set
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -35,6 +35,7 @@ class Environment(BaseModel):
|
|||
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
|
||||
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
|
||||
history: str = "" # For debug
|
||||
context: Context = Field(default_factory=Context, exclude=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_roles(self):
|
||||
|
|
@ -85,6 +86,7 @@ class Environment(BaseModel):
|
|||
"""
|
||||
self.roles[role.profile] = role
|
||||
role.set_env(self)
|
||||
role.context = self.context
|
||||
|
||||
def add_roles(self, roles: Iterable[Role]):
|
||||
"""增加一批在当前环境的角色
|
||||
|
|
@ -95,6 +97,7 @@ class Environment(BaseModel):
|
|||
|
||||
for role in roles: # setup system message with roles
|
||||
role.set_env(self)
|
||||
role.context = self.context
|
||||
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
"""
|
||||
|
|
@ -162,7 +165,6 @@ class Environment(BaseModel):
|
|||
"""Set the labels for message to be consumed by the object"""
|
||||
self.members[obj] = tags
|
||||
|
||||
@staticmethod
|
||||
def archive(auto_archive=True):
|
||||
if auto_archive and CONFIG.git_repo:
|
||||
CONFIG.git_repo.archive()
|
||||
def archive(self, auto_archive=True):
|
||||
if auto_archive and self.context.git_repo:
|
||||
self.context.git_repo.archive()
|
||||
|
|
|
|||
|
|
@ -8,33 +8,37 @@
|
|||
"""
|
||||
import base64
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
|
||||
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
|
||||
async def text_to_image(text, size_type: str = "512x512", model_url="", config: Config = None):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
:param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
|
||||
:param model_url: MetaGPT model url
|
||||
:param config: Config
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
image_declaration = "data:image/png;base64,"
|
||||
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
|
||||
|
||||
if model_url:
|
||||
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
elif CONFIG.OPENAI_API_KEY or openai_api_key:
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type)
|
||||
elif oai_llm := config.get_openai_llm():
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type, LLM(oai_llm))
|
||||
else:
|
||||
raise ValueError("Missing necessary parameters.")
|
||||
base64_data = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
assert config.s3, "S3 config is required."
|
||||
s3 = S3(config.s3)
|
||||
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT)
|
||||
if url:
|
||||
return f""
|
||||
return image_declaration + base64_data if base64_data else ""
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def text_to_speech(
|
|||
audio_declaration = "data:audio/wav;base64,"
|
||||
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT)
|
||||
if url:
|
||||
return f"[{text}]({url})"
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
|
@ -60,7 +60,7 @@ async def text_to_speech(
|
|||
text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
|
||||
)
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT)
|
||||
if url:
|
||||
return f"[{text}]({url})"
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
|
|
|||
|
|
@ -8,17 +8,10 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.context import context
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
if provider is None:
|
||||
provider = CONFIG.get_default_llm_provider_enum()
|
||||
|
||||
return LLM_REGISTRY.get_provider(provider)
|
||||
def LLM(name: Optional[str] = None) -> BaseLLM:
|
||||
"""get the default llm provider if name is None"""
|
||||
return context.llm(name)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from metagpt.provider.openai_api import OpenAILLM
|
|||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
|
|
@ -24,4 +25,5 @@ __all__ = [
|
|||
"AzureOpenAILLM",
|
||||
"MetaGPTLLM",
|
||||
"OllamaLLM",
|
||||
"HumanProvider",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,12 +9,15 @@
|
|||
import anthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
|
||||
|
||||
class Claude2:
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.config = config
|
||||
|
||||
def ask(self, prompt: str) -> str:
|
||||
client = Anthropic(api_key=CONFIG.anthropic_api_key)
|
||||
client = Anthropic(api_key=self.config.api_key)
|
||||
|
||||
res = client.completions.create(
|
||||
model="claude-2",
|
||||
|
|
@ -24,7 +27,7 @@ class Claude2:
|
|||
return res.completion
|
||||
|
||||
async def aask(self, prompt: str) -> str:
|
||||
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
|
||||
aclient = AsyncAnthropic(api_key=self.config.api_key)
|
||||
|
||||
res = await aclient.completions.create(
|
||||
model="claude-2",
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
from openai import AsyncAzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.AZURE_OPENAI)
|
||||
@register_provider(LLMType.AZURE_OPENAI)
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
|
|
|
|||
|
|
@ -8,15 +8,30 @@
|
|||
"""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
config: LLMConfig
|
||||
use_system_prompt: bool = True
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
# OpenAI / Azure / Others
|
||||
aclient: Optional[Union[AsyncOpenAI]] = None
|
||||
cost_manager: Optional[CostManager] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "user", "content": msg}
|
||||
|
||||
|
|
@ -63,10 +78,9 @@ class BaseLLM(ABC):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
|
|
@ -64,27 +64,18 @@ class FireworksCostManager(CostManager):
|
|||
token_costs = self.model_grade_token_costs(model)
|
||||
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
|
||||
self.total_cost += cost
|
||||
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${max_budget:.3f} | "
|
||||
f"Total running cost: ${self.total_cost:.4f}"
|
||||
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
@register_provider(LLMType.FIREWORKS)
|
||||
class FireworksLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_fireworks()
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
super().__init__(config=config)
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = FireworksCostManager()
|
||||
|
||||
def __init_fireworks(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._init_client()
|
||||
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
self.cost_manager = FireworksCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
|
||||
|
|
@ -94,14 +85,14 @@ class FireworksLLM(OpenAILLM):
|
|||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -41,21 +41,21 @@ class GeminiGenerativeModel(GenerativeModel):
|
|||
return await self._async_client.count_tokens(model=self.model_name, contents=contents)
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.GEMINI)
|
||||
@register_provider(LLMType.GEMINI)
|
||||
class GeminiLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.use_system_prompt = False # google gemini has no system prompt when use api
|
||||
|
||||
self.__init_gemini(CONFIG)
|
||||
self.__init_gemini(config)
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
def __init_gemini(self, config: CONFIG):
|
||||
genai.configure(api_key=config.gemini_api_key)
|
||||
def __init_gemini(self, config: LLMConfig):
|
||||
genai.configure(api_key=config.api_key)
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
# Not to change BaseLLM default functions but update with Gemini's conversation format.
|
||||
|
|
@ -71,11 +71,11 @@ class GeminiLLM(BaseLLM):
|
|||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
|
|
@ -108,7 +108,7 @@ class GeminiLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Author: garylin2099
|
|||
"""
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
|
@ -14,6 +15,9 @@ class HumanProvider(BaseLLM):
|
|||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
rsp = input(msg)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
@Author : alexanderwu
|
||||
@File : llm_provider_registry.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class LLMProviderRegistry:
|
||||
|
|
@ -15,13 +16,9 @@ class LLMProviderRegistry:
|
|||
def register(self, key, provider_cls):
|
||||
self.providers[key] = provider_cls
|
||||
|
||||
def get_provider(self, enum: LLMProviderEnum):
|
||||
def get_provider(self, enum: LLMType):
|
||||
"""get provider instance according to the enum"""
|
||||
return self.providers[enum]()
|
||||
|
||||
|
||||
# Registry instance
|
||||
LLM_REGISTRY = LLMProviderRegistry()
|
||||
return self.providers[enum]
|
||||
|
||||
|
||||
def register_provider(key):
|
||||
|
|
@ -32,3 +29,12 @@ def register_provider(key):
|
|||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_llm(config: LLMConfig) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
return LLM_REGISTRY.get_provider(config.api_type)(config)
|
||||
|
||||
|
||||
# Registry instance
|
||||
LLM_REGISTRY = LLMProviderRegistry()
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@
|
|||
@File : metagpt_api.py
|
||||
@Desc : MetaGPT LLM provider.
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.METAGPT)
|
||||
@register_provider(LLMType.METAGPT)
|
||||
class MetaGPTLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,48 +13,33 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
|
||||
|
||||
class OllamaCostManager(CostManager):
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
|
||||
logger.info(
|
||||
f"Max budget: ${max_budget:.3f} | "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OLLAMA)
|
||||
@register_provider(LLMType.OLLAMA)
|
||||
class OllamaLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_ollama(CONFIG)
|
||||
self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.__init_ollama(config)
|
||||
self.client = GeneralAPIRequestor(base_url=config.api_base)
|
||||
self.suffix_url = "/chat"
|
||||
self.http_method = "post"
|
||||
self.use_system_prompt = False
|
||||
self._cost_manager = OllamaCostManager()
|
||||
self._cost_manager = TokenCostManager()
|
||||
|
||||
def __init_ollama(self, config: CONFIG):
|
||||
assert config.ollama_api_base
|
||||
self.model = config.ollama_api_model
|
||||
def __init_ollama(self, config: LLMConfig):
|
||||
assert config.api_base
|
||||
self.model = config.model
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
|
|
@ -62,7 +47,7 @@ class OllamaLLM(BaseLLM):
|
|||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
|
|
|
|||
|
|
@ -4,56 +4,27 @@
|
|||
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import Costs, TokenCostManager
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
class OpenLLMCostManager(CostManager):
|
||||
"""open llm model is self-host, it's free and without cost"""
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
|
||||
logger.info(
|
||||
f"Max budget: ${max_budget:.3f} | reference "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
@register_provider(LLMType.OPEN_LLM)
|
||||
class OpenLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = OpenLLMCostManager()
|
||||
|
||||
def __init_openllm(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._init_client()
|
||||
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config)
|
||||
self._cost_manager = TokenCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if not CONFIG.calc_usage:
|
||||
if not self.config.calc_usage:
|
||||
return usage
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import AsyncIterator, Union
|
||||
from typing import AsyncIterator, Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
|
@ -24,13 +24,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -50,18 +50,19 @@ See FAQ 5.8
|
|||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPENAI)
|
||||
@register_provider(LLMType.OPENAI)
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self._init_openai()
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def _init_openai(self):
|
||||
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
|
|
@ -69,7 +70,7 @@ class OpenAILLM(BaseLLM):
|
|||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
|
||||
kwargs = {"api_key": self.config.api_key, "base_url": self.config.base_url}
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
if proxy_params := self._get_proxy_params():
|
||||
|
|
@ -79,10 +80,10 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def _get_proxy_params(self) -> dict:
|
||||
params = {}
|
||||
if self.config.openai_proxy:
|
||||
params = {"proxies": self.config.openai_proxy}
|
||||
if self.config.openai_base_url:
|
||||
params["base_url"] = self.config.openai_base_url
|
||||
if self.config.proxy:
|
||||
params = {"proxies": self.config.proxy}
|
||||
if self.config.base_url:
|
||||
params["base_url"] = self.config.base_url
|
||||
|
||||
return params
|
||||
|
||||
|
|
@ -103,7 +104,7 @@ class OpenAILLM(BaseLLM):
|
|||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
"timeout": max(CONFIG.timeout, timeout),
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
}
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
|
|
@ -205,7 +206,7 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if not CONFIG.calc_usage:
|
||||
if not self.config.calc_usage:
|
||||
return usage
|
||||
|
||||
try:
|
||||
|
|
@ -218,16 +219,16 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage and usage:
|
||||
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
if self.config.calc_usage and usage:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
return self.config.max_token
|
||||
return get_max_completion_tokens(messages, self.model, self.config.max_tokens)
|
||||
|
||||
@handle_exception
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
|
|
|
|||
|
|
@ -16,15 +16,16 @@ from wsgiref.handlers import format_date_time
|
|||
|
||||
import websocket # 使用websocket_client
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
@register_provider(LLMType.SPARK)
|
||||
class SparkLLM(BaseLLM):
|
||||
def __init__(self):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.config = config
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
|
|
@ -33,12 +34,12 @@ class SparkLLM(BaseLLM):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
# 不支持
|
||||
logger.error("该功能禁用。")
|
||||
w = GetMessageFromWeb(messages)
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages)
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
|
||||
|
|
@ -89,14 +90,14 @@ class GetMessageFromWeb:
|
|||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def __init__(self, text):
|
||||
def __init__(self, text, config):
|
||||
self.text = text
|
||||
self.ret = ""
|
||||
self.spark_appid = CONFIG.spark_appid
|
||||
self.spark_api_secret = CONFIG.spark_api_secret
|
||||
self.spark_api_key = CONFIG.spark_api_key
|
||||
self.domain = CONFIG.domain
|
||||
self.spark_url = CONFIG.spark_url
|
||||
self.spark_appid = config.app_id
|
||||
self.spark_api_secret = config.api_secret
|
||||
self.spark_api_key = config.api_key
|
||||
self.domain = config.domain
|
||||
self.spark_url = config.base_url
|
||||
|
||||
def on_message(self, ws, message):
|
||||
data = json.loads(message)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -31,27 +31,27 @@ class ZhiPuEvent(Enum):
|
|||
FINISH = "finish"
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.ZHIPUAI)
|
||||
@register_provider(LLMType.ZHIPUAI)
|
||||
class ZhiPuAILLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_zhipuai(CONFIG)
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.__init_zhipuai(config)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
|
||||
|
||||
def __init_zhipuai(self, config: CONFIG):
|
||||
assert config.zhipuai_api_key
|
||||
zhipuai.api_key = config.zhipuai_api_key
|
||||
def __init_zhipuai(self, config: LLMConfig):
|
||||
assert config.api_key
|
||||
zhipuai.api_key = config.api_key
|
||||
# due to use openai sdk, set the api_key but it will't be used.
|
||||
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
|
||||
if config.openai_proxy:
|
||||
if config.proxy:
|
||||
# FIXME: openai v1.x sdk has no proxy support
|
||||
openai.proxy = config.openai_proxy
|
||||
openai.proxy = config.proxy
|
||||
|
||||
def _const_kwargs(self, messages: list[dict]) -> dict:
|
||||
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
|
||||
|
|
@ -59,11 +59,11 @@ class ZhiPuAILLM(BaseLLM):
|
|||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ from typing import Set
|
|||
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
CODE_SUMMARIES_PDF_FILE_REPO,
|
||||
|
|
@ -80,6 +79,7 @@ class Engineer(Role):
|
|||
code_todos: list = []
|
||||
summarize_todos: list = []
|
||||
next_todo_action: str = ""
|
||||
n_summarize: int = 0
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -97,7 +97,7 @@ class Engineer(Role):
|
|||
|
||||
async def _act_sp_with_cr(self, review=False) -> Set[str]:
|
||||
changed_files = set()
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
|
||||
for todo in self.code_todos:
|
||||
"""
|
||||
# Select essential information from the historical data to reduce the length of the prompt (summarized from human experience):
|
||||
|
|
@ -153,10 +153,10 @@ class Engineer(Role):
|
|||
)
|
||||
|
||||
async def _act_summarize(self):
|
||||
code_summaries_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
|
||||
code_summaries_pdf_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
|
||||
code_summaries_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
|
||||
code_summaries_pdf_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
|
||||
tasks = []
|
||||
src_relative_path = CONFIG.src_workspace.relative_to(CONFIG.git_repo.workdir)
|
||||
src_relative_path = self.src_workspace.relative_to(self.git_repo.workdir)
|
||||
for todo in self.summarize_todos:
|
||||
summary = await todo.run()
|
||||
summary_filename = Path(todo.context.design_filename).with_suffix(".md").name
|
||||
|
|
@ -179,8 +179,8 @@ class Engineer(Role):
|
|||
else:
|
||||
await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name)
|
||||
|
||||
logger.info(f"--max-auto-summarize-code={CONFIG.max_auto_summarize_code}")
|
||||
if not tasks or CONFIG.max_auto_summarize_code == 0:
|
||||
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
|
||||
if not tasks or self.config.max_auto_summarize_code == 0:
|
||||
return Message(
|
||||
content="",
|
||||
role=self.profile,
|
||||
|
|
@ -190,7 +190,7 @@ class Engineer(Role):
|
|||
)
|
||||
# The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited.
|
||||
# This parameter is used for debugging the workflow.
|
||||
CONFIG.max_auto_summarize_code -= 1 if CONFIG.max_auto_summarize_code > 0 else 0
|
||||
self.n_summarize += 1 if self.config.max_auto_summarize_code > self.n_summarize else 0
|
||||
return Message(
|
||||
content=json.dumps(tasks), role=self.profile, cause_by=SummarizeCode, send_to=self, sent_from=self
|
||||
)
|
||||
|
|
@ -203,8 +203,8 @@ class Engineer(Role):
|
|||
return False, rsp
|
||||
|
||||
async def _think(self) -> Action | None:
|
||||
if not CONFIG.src_workspace:
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
|
||||
if not self.src_workspace:
|
||||
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
|
||||
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
|
||||
if not self.rc.news:
|
||||
|
|
@ -253,11 +253,11 @@ class Engineer(Role):
|
|||
|
||||
async def _new_code_actions(self, bug_fix=False):
|
||||
# Prepare file repos
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
|
||||
changed_src_files = src_file_repo.all_files if bug_fix else src_file_repo.changed_files
|
||||
task_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
|
||||
task_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO)
|
||||
changed_task_files = task_file_repo.changed_files
|
||||
design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
|
||||
changed_files = Documents()
|
||||
# Recode caused by upstream changes.
|
||||
|
|
@ -283,7 +283,7 @@ class Engineer(Role):
|
|||
changed_files.docs[task_filename] = coding_doc
|
||||
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
|
||||
# Code directly modified by the user.
|
||||
dependency = await CONFIG.git_repo.get_dependency()
|
||||
dependency = await self.git_repo.get_dependency()
|
||||
for filename in changed_src_files:
|
||||
if filename in changed_files.docs:
|
||||
continue
|
||||
|
|
@ -301,7 +301,7 @@ class Engineer(Role):
|
|||
self.rc.todo = self.code_todos[0]
|
||||
|
||||
async def _new_summarize_actions(self):
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
|
||||
src_files = src_file_repo.all_files
|
||||
# Generate a SummarizeCode action for each pair of (system_design_doc, task_doc).
|
||||
summarizations = defaultdict(list)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@
|
|||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.utils.common import any_to_name
|
||||
|
||||
|
|
@ -40,11 +39,11 @@ class ProductManager(Role):
|
|||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if CONFIG.git_repo and not CONFIG.git_reinit:
|
||||
if self.git_repo and not self.config.git_reinit:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
CONFIG.git_reinit = False
|
||||
self.context.config.git_reinit = False
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
return bool(self.rc.todo)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,10 +15,9 @@
|
|||
of SummarizeCode.
|
||||
"""
|
||||
|
||||
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_TO_NONE,
|
||||
TEST_CODES_FILE_REPO,
|
||||
|
|
@ -50,13 +49,17 @@ class QaEngineer(Role):
|
|||
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
|
||||
self.test_round = 0
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self.context.config
|
||||
|
||||
async def _write_test(self, message: Message) -> None:
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
src_file_repo = self.context.git_repo.new_file_repository(self.context.src_workspace)
|
||||
changed_files = set(src_file_repo.changed_files.keys())
|
||||
# Unit tests only.
|
||||
if CONFIG.reqa_file and CONFIG.reqa_file not in changed_files:
|
||||
changed_files.add(CONFIG.reqa_file)
|
||||
tests_file_repo = CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
|
||||
if self.config.reqa_file and self.config.reqa_file not in changed_files:
|
||||
changed_files.add(self.config.reqa_file)
|
||||
tests_file_repo = self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
|
||||
for filename in changed_files:
|
||||
# write tests
|
||||
if not filename or "test" in filename:
|
||||
|
|
@ -69,7 +72,7 @@ class QaEngineer(Role):
|
|||
)
|
||||
logger.info(f"Writing {test_doc.filename}..")
|
||||
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
|
||||
context = await WriteTest(context=context, llm=self.llm).run()
|
||||
context = await WriteTest(context=context, _context=self.context, llm=self.llm).run()
|
||||
await tests_file_repo.save(
|
||||
filename=context.test_doc.filename,
|
||||
content=context.test_doc.content,
|
||||
|
|
@ -81,8 +84,8 @@ class QaEngineer(Role):
|
|||
command=["python", context.test_doc.root_relative_path],
|
||||
code_filename=context.code_doc.filename,
|
||||
test_filename=context.test_doc.filename,
|
||||
working_directory=str(CONFIG.git_repo.workdir),
|
||||
additional_python_paths=[str(CONFIG.src_workspace)],
|
||||
working_directory=str(self.context.git_repo.workdir),
|
||||
additional_python_paths=[str(self.context.src_workspace)],
|
||||
)
|
||||
self.publish_message(
|
||||
Message(
|
||||
|
|
@ -98,17 +101,21 @@ class QaEngineer(Role):
|
|||
|
||||
async def _run_code(self, msg):
|
||||
run_code_context = RunCodeContext.loads(msg.content)
|
||||
src_doc = await CONFIG.git_repo.new_file_repository(CONFIG.src_workspace).get(run_code_context.code_filename)
|
||||
src_doc = await self.context.git_repo.new_file_repository(self.context.src_workspace).get(
|
||||
run_code_context.code_filename
|
||||
)
|
||||
if not src_doc:
|
||||
return
|
||||
test_doc = await CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(run_code_context.test_filename)
|
||||
test_doc = await self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(
|
||||
run_code_context.test_filename
|
||||
)
|
||||
if not test_doc:
|
||||
return
|
||||
run_code_context.code = src_doc.content
|
||||
run_code_context.test_code = test_doc.content
|
||||
result = await RunCode(context=run_code_context, llm=self.llm).run()
|
||||
run_code_context.output_filename = run_code_context.test_filename + ".json"
|
||||
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
|
||||
await self.context.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
|
||||
filename=run_code_context.output_filename,
|
||||
content=result.model_dump_json(),
|
||||
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
|
||||
|
|
|
|||
|
|
@ -32,9 +32,11 @@ from metagpt.actions import Action, ActionOutput
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.llm import LLM, HumanProvider
|
||||
from metagpt.context import Context
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider import HumanProvider
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.utils.common import (
|
||||
|
|
@ -148,9 +150,46 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
# builtin variables
|
||||
recovered: bool = False # to tag if a recovered role
|
||||
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
|
||||
context: Optional[Context] = Field(default=None, exclude=True)
|
||||
|
||||
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.context.config
|
||||
|
||||
@property
|
||||
def git_repo(self):
|
||||
return self.context.git_repo
|
||||
|
||||
@git_repo.setter
|
||||
def git_repo(self, value):
|
||||
self.context.git_repo = value
|
||||
|
||||
@property
|
||||
def src_workspace(self):
|
||||
return self.context.src_workspace
|
||||
|
||||
@src_workspace.setter
|
||||
def src_workspace(self, value):
|
||||
self.context.src_workspace = value
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
return self.context.config.prompt_schema
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return self.context.config.project_name
|
||||
|
||||
@project_name.setter
|
||||
def project_name(self, value):
|
||||
self.context.config.project_name = value
|
||||
|
||||
@property
|
||||
def project_path(self):
|
||||
return self.context.config.project_path
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_subscription(self):
|
||||
if not self.subscription:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import aiofiles
|
|||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -81,7 +80,7 @@ class Teacher(Role):
|
|||
async def save(self, content):
|
||||
"""Save teaching plan"""
|
||||
filename = Teacher.new_file_name(self.course_title)
|
||||
pathname = CONFIG.workspace_path / "teaching_plan"
|
||||
pathname = self.config.workspace.path / "teaching_plan"
|
||||
pathname.mkdir(exist_ok=True)
|
||||
pathname = pathname / filename
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ from pydantic import (
|
|||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
MESSAGE_ROUTE_FROM,
|
||||
|
|
@ -151,12 +150,6 @@ class Document(BaseModel):
|
|||
"""
|
||||
return os.path.join(self.root_path, self.filename)
|
||||
|
||||
@property
|
||||
def full_path(self):
|
||||
if not CONFIG.git_repo:
|
||||
return None
|
||||
return str(CONFIG.git_repo.workdir / self.root_path / self.filename)
|
||||
|
||||
def __str__(self):
|
||||
return self.content
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pathlib import Path
|
|||
|
||||
import typer
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
|
||||
app = typer.Typer(add_completion=False)
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ def startup(
|
|||
)
|
||||
from metagpt.team import Team
|
||||
|
||||
CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
|
||||
config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
|
||||
|
||||
if not recover_path:
|
||||
company = Team()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from typing import Any
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -79,18 +78,20 @@ class Team(BaseModel):
|
|||
"""Hire roles to cooperate"""
|
||||
self.env.add_roles(roles)
|
||||
|
||||
@property
|
||||
def cost_manager(self):
|
||||
"""Get cost manager"""
|
||||
return self.env.context.cost_manager
|
||||
|
||||
def invest(self, investment: float):
|
||||
"""Invest company. raise NoMoneyException when exceed max_budget."""
|
||||
self.investment = investment
|
||||
CONFIG.max_budget = investment
|
||||
self.cost_manager.max_budget = investment
|
||||
logger.info(f"Investment: ${investment}.")
|
||||
|
||||
@staticmethod
|
||||
def _check_balance():
|
||||
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
|
||||
raise NoMoneyException(
|
||||
CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}"
|
||||
)
|
||||
def _check_balance(self):
|
||||
if self.cost_manager.total_cost > self.cost_manager.max_budget:
|
||||
raise NoMoneyException(self.cost_manager.total_cost, f"Insufficient funds: {self.cost_manager.max_budget}")
|
||||
|
||||
def run_project(self, idea, send_to: str = ""):
|
||||
"""Run a project from publishing user requirement."""
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import aiohttp
|
|||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -22,7 +21,7 @@ class MetaGPTText2Image:
|
|||
"""
|
||||
:param model_url: Model reset api url
|
||||
"""
|
||||
self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL
|
||||
self.model_url = model_url
|
||||
|
||||
async def text_2_image(self, text, size_type="512x512"):
|
||||
"""Text to image
|
||||
|
|
@ -93,6 +92,4 @@ async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url
|
|||
"""
|
||||
if not text:
|
||||
return ""
|
||||
if not model_url:
|
||||
model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type)
|
||||
|
|
|
|||
|
|
@ -10,16 +10,16 @@
|
|||
import aiohttp
|
||||
import requests
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class OpenAIText2Image:
|
||||
def __init__(self):
|
||||
def __init__(self, llm: BaseLLM):
|
||||
"""
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self._llm = LLM()
|
||||
self.llm = llm
|
||||
|
||||
async def text_2_image(self, text, size_type="1024x1024"):
|
||||
"""Text to image
|
||||
|
|
@ -29,7 +29,7 @@ class OpenAIText2Image:
|
|||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
try:
|
||||
result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type)
|
||||
result = await self.llm.aclient.images.generate(prompt=text, n=1, size=size_type)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
|
|
@ -57,13 +57,14 @@ class OpenAIText2Image:
|
|||
|
||||
|
||||
# Export
|
||||
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"):
|
||||
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024", llm: BaseLLM = None):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param size_type: One of ['256x256', '512x512', '1024x1024']
|
||||
:param llm: LLM instance
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
return await OpenAIText2Image().text_2_image(text, size_type=size_type)
|
||||
return await OpenAIText2Image(llm).text_2_image(text, size_type=size_type)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class SDEngine:
|
|||
return self.payload
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
save_dir = CONFIG.path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
|
|
|||
|
|
@ -80,3 +80,20 @@ class CostManager(BaseModel):
|
|||
def get_costs(self) -> Costs:
|
||||
"""Get all costs"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
|
||||
|
||||
class TokenCostManager(CostManager):
|
||||
"""open llm model is self-host, it's free and without cost"""
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
|
||||
|
|
|
|||
16
metagpt/utils/embedding.py
Normal file
16
metagpt/utils/embedding.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 20:58
|
||||
@Author : alexanderwu
|
||||
@File : embedding.py
|
||||
"""
|
||||
from langchain_community.embeddings import OpenAIEmbeddings
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_embedding():
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
return embedding
|
||||
|
|
@ -12,26 +12,25 @@ from datetime import timedelta
|
|||
|
||||
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class Redis:
|
||||
def __init__(self):
|
||||
def __init__(self, config: RedisConfig = None):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def _connect(self, force=False):
|
||||
if self._client and not force:
|
||||
return True
|
||||
if not self.is_configured:
|
||||
return False
|
||||
|
||||
try:
|
||||
self._client = await aioredis.from_url(
|
||||
f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
|
||||
username=CONFIG.REDIS_USER,
|
||||
password=CONFIG.REDIS_PASSWORD,
|
||||
db=CONFIG.REDIS_DB,
|
||||
self.config.to_url(),
|
||||
username=self.config.username,
|
||||
password=self.config.password,
|
||||
db=self.config.db,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -62,18 +61,3 @@ class Redis:
|
|||
return
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(
|
||||
CONFIG.REDIS_HOST
|
||||
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
|
||||
and CONFIG.REDIS_PORT
|
||||
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
|
||||
and CONFIG.REDIS_DB is not None
|
||||
and CONFIG.REDIS_PASSWORD is not None
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Optional
|
|||
import aioboto3
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import S3Config
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
|
@ -16,13 +16,14 @@ from metagpt.logs import logger
|
|||
class S3:
|
||||
"""A class for interacting with Amazon S3 storage."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: S3Config):
|
||||
self.session = aioboto3.Session()
|
||||
self.config = config
|
||||
self.auth_config = {
|
||||
"service_name": "s3",
|
||||
"aws_access_key_id": CONFIG.S3_ACCESS_KEY,
|
||||
"aws_secret_access_key": CONFIG.S3_SECRET_KEY,
|
||||
"endpoint_url": CONFIG.S3_ENDPOINT_URL,
|
||||
"aws_access_key_id": config.access_key,
|
||||
"aws_secret_access_key": config.secret_key,
|
||||
"endpoint_url": config.endpoint,
|
||||
}
|
||||
|
||||
async def upload_file(
|
||||
|
|
@ -139,8 +140,8 @@ class S3:
|
|||
data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
|
||||
await file.write(data)
|
||||
|
||||
bucket = CONFIG.S3_BUCKET
|
||||
object_pathname = CONFIG.S3_BUCKET or "system"
|
||||
bucket = self.config.bucket
|
||||
object_pathname = self.config.bucket or "system"
|
||||
object_pathname += f"/{object_name}"
|
||||
object_pathname = os.path.normpath(object_pathname)
|
||||
await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
|
||||
|
|
@ -151,20 +152,3 @@ class S3:
|
|||
logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
pathname.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_valid(self):
|
||||
return self.is_configured
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(
|
||||
CONFIG.S3_ACCESS_KEY
|
||||
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
|
||||
and CONFIG.S3_SECRET_KEY
|
||||
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
|
||||
and CONFIG.S3_ENDPOINT_URL
|
||||
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
|
||||
and CONFIG.S3_BUCKET
|
||||
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
|
||||
)
|
||||
|
|
|
|||
38
metagpt/utils/yaml_model.py
Normal file
38
metagpt/utils/yaml_model.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 10:18
|
||||
@Author : alexanderwu
|
||||
@File : YamlModel.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class YamlModel(BaseModel):
|
||||
extra_fields: Optional[Dict[str, str]] = None
|
||||
|
||||
@classmethod
|
||||
def read_yaml(cls, file_path: Path) -> Dict:
|
||||
with open(file_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
@classmethod
|
||||
def model_validate_yaml(cls, file_path: Path) -> "YamlModel":
|
||||
return cls(**cls.read_yaml(file_path))
|
||||
|
||||
def model_dump_yaml(self, file_path: Path) -> None:
|
||||
with open(file_path, "w") as file:
|
||||
yaml.dump(self.model_dump(), file)
|
||||
|
||||
|
||||
class YamlModelWithoutDefault(YamlModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_not_default_config(cls, values):
|
||||
if any(["YOUR" in v for v in values]):
|
||||
raise ValueError("Please set your S3 config in config.yaml")
|
||||
return values
|
||||
|
|
@ -10,29 +10,26 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metagpt_llm():
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
async def test_metagpt_text_to_image():
|
||||
config = Config()
|
||||
assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
data = await text_to_image("Panda emoji", size_type="512x512", model_url=config.METAGPT_TEXT_TO_IMAGE_MODEL_URL)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_text_to_image():
|
||||
config = Config.default()
|
||||
assert config.get_openai_llm()
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -46,7 +46,7 @@ def test_extract_info(input, tag, val):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
|
||||
@pytest.mark.parametrize("llm", [LLM(provider=LLMType.OPENAI), LLM(provider=LLMType.METAGPT)])
|
||||
async def test_memory_llm(llm):
|
||||
memory = BrainMemory()
|
||||
for i in range(500):
|
||||
|
|
|
|||
|
|
@ -3,12 +3,8 @@
|
|||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
|
||||
CONFIG.OPENAI_API_VERSION = "xx"
|
||||
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
|
||||
from metagpt.context import context
|
||||
|
||||
|
||||
def test_azure_openai_api():
|
||||
_ = AzureOpenAILLM()
|
||||
_ = context.llm("azure")
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -28,6 +29,9 @@ resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
|||
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
|
|
@ -102,5 +106,5 @@ async def test_async_base_llm():
|
|||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_llm.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
# resp = await base_llm.aask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
@Author : mashenquan
|
||||
@File : test_metagpt_api.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_llm():
|
||||
llm = LLM(provider=LLMProviderEnum.METAGPT)
|
||||
llm = LLM(provider=LLMType.METAGPT)
|
||||
assert llm
|
||||
|
|
|
|||
|
|
@ -2,18 +2,18 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
CONFIG.openai_proxy = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAILLM()
|
||||
llm = LLM(name="gpt3t")
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
|
||||
logger.info(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
|
@ -21,7 +21,7 @@ async def test_aask_code():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAILLM()
|
||||
llm = LLM(name="gpt3t")
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -30,8 +30,8 @@ async def test_aask_code_str():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAILLM()
|
||||
async def test_aask_code_message():
|
||||
llm = LLM(name="gpt3t")
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
|
|||
|
|
@ -28,6 +28,6 @@ def load_existing_repo(path):
|
|||
|
||||
|
||||
def test_repo_set_load():
|
||||
repo_path = CONFIG.workspace_path / "test_repo"
|
||||
repo_path = CONFIG.path / "test_repo"
|
||||
set_existing_repo(repo_path)
|
||||
load_existing_repo(repo_path)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import pytest
|
|||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
|
|
@ -119,8 +118,6 @@ def test_document():
|
|||
assert doc.filename == meta_doc.filename
|
||||
assert meta_doc.content == ""
|
||||
|
||||
assert doc.full_path == str(CONFIG.git_repo.workdir / doc.root_path / doc.filename)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_queue():
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ async def test_azure_tts():
|
|||
“Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.”
|
||||
</mstts:express-as>
|
||||
"""
|
||||
path = CONFIG.workspace_path / "tts"
|
||||
path = CONFIG.path / "tts"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
filename = path / "girl.wav"
|
||||
filename.unlink(missing_ok=True)
|
||||
|
|
|
|||
|
|
@ -22,5 +22,5 @@ def test_sd_engine_generate_prompt():
|
|||
async def test_sd_engine_run_t2i():
|
||||
sd_engine = SDEngine()
|
||||
await sd_engine.run_t2i(prompts=["test"])
|
||||
img_path = CONFIG.workspace_path / "resources" / "SD_Output" / "output_0.png"
|
||||
img_path = CONFIG.path / "resources" / "SD_Output" / "output_0.png"
|
||||
assert os.path.exists(img_path)
|
||||
|
|
|
|||
|
|
@ -8,38 +8,19 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.utils.redis import Redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis():
|
||||
# Prerequisites
|
||||
assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
|
||||
assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
|
||||
# assert CONFIG.REDIS_USER
|
||||
assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD"
|
||||
assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based"
|
||||
redis = Config.default().redis
|
||||
|
||||
conn = Redis()
|
||||
assert not conn.is_valid
|
||||
conn = Redis(redis)
|
||||
await conn.set("test", "test", timeout_sec=0)
|
||||
assert await conn.get("test") == b"test"
|
||||
await conn.close()
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["REDIS_HOST"] = "YOUR_REDIS_HOST"
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
conn = Redis()
|
||||
await conn.set("test", "test", timeout_sec=0)
|
||||
assert not await conn.get("test") == b"test"
|
||||
await conn.close()
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -11,30 +11,25 @@ from pathlib import Path
|
|||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3():
|
||||
# Prerequisites
|
||||
assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
|
||||
assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
|
||||
assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
|
||||
# assert CONFIG.S3_SECURE: true # true/false
|
||||
assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
|
||||
|
||||
conn = S3()
|
||||
assert conn.is_valid
|
||||
s3 = Config.default().s3
|
||||
assert s3
|
||||
conn = S3(s3)
|
||||
object_name = "unittest.bak"
|
||||
await conn.upload_file(bucket=CONFIG.S3_BUCKET, local_path=__file__, object_name=object_name)
|
||||
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
|
||||
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname.unlink(missing_ok=True)
|
||||
await conn.download_file(bucket=CONFIG.S3_BUCKET, object_name=object_name, local_path=str(pathname))
|
||||
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
|
||||
assert pathname.exists()
|
||||
url = await conn.get_object_url(bucket=CONFIG.S3_BUCKET, object_name=object_name)
|
||||
url = await conn.get_object_url(bucket=s3.bucket, object_name=object_name)
|
||||
assert url
|
||||
bin_data = await conn.get_object(bucket=CONFIG.S3_BUCKET, object_name=object_name)
|
||||
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
|
||||
assert bin_data
|
||||
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read()
|
||||
|
|
@ -42,17 +37,13 @@ async def test_s3():
|
|||
assert "http" in res
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
|
||||
CONFIG.set_context(new_options)
|
||||
s3.access_key = "ABC"
|
||||
try:
|
||||
conn = S3()
|
||||
assert not conn.is_valid
|
||||
conn = S3(s3)
|
||||
res = await conn.cache("ABC", ".bak", "script")
|
||||
assert not res
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue