add context and config2

This commit is contained in:
geekan 2024-01-04 23:33:09 +08:00
parent 10436172ca
commit 5c1f3a4b91
8 changed files with 34 additions and 34 deletions

View file

@ -23,6 +23,7 @@ from metagpt.schema import (
SerializationMixin,
TestingContext,
)
from metagpt.utils.file_repository import FileRepository
class Action(SerializationMixin, is_polymorphic_base=True):
@ -40,6 +41,10 @@ class Action(SerializationMixin, is_polymorphic_base=True):
def git_repo(self):
return self.g_context.git_repo
@property
def file_repo(self):
return FileRepository(self.g_context.git_repo)
@property
def src_workspace(self):
return self.g_context.src_workspace

View file

@ -47,7 +47,7 @@ class PrepareDocuments(Action):
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO)
await self.file_repo.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO)
# Send a Message notification to the WritePRD action, instructing it to process requirements using
# `docs/requirement.txt` and `docs/prds/`.

View file

@ -95,14 +95,14 @@ class WriteCode(Action):
return code
async def run(self, *args, **kwargs) -> CodingContext:
bug_feedback = await FileRepository.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO)
bug_feedback = await self.file_repo.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO)
coding_context = CodingContext.loads(self.context.content)
test_doc = await FileRepository.get_file(
test_doc = await self.file_repo.get_file(
filename="test_" + coding_context.filename + ".json", relative_path=TEST_OUTPUTS_FILE_REPO
)
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
summary_doc = await FileRepository.get_file(
summary_doc = await self.file_repo.get_file(
filename=coding_context.design_doc.filename, relative_path=CODE_SUMMARIES_FILE_REPO
)
logs = ""

View file

@ -49,8 +49,3 @@ class ProductManager(Role):
async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.todo_action

View file

@ -16,7 +16,6 @@ from typing import Dict, List, Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
@ -201,8 +200,7 @@ class FileRepository:
await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
logger.debug(f"File Saved: {str(filename)}")
@staticmethod
async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
async def get_file(self, filename: Path | str, relative_path: Path | str = ".") -> Document | None:
"""Retrieve a specific file from the file repository.
:param filename: The name or path of the file to retrieve.
@ -212,11 +210,10 @@ class FileRepository:
:return: The document representing the file, or None if not found.
:rtype: Document or None
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get(filename=filename)
@staticmethod
async def get_all_files(relative_path: Path | str = ".") -> List[Document]:
async def get_all_files(self, relative_path: Path | str = ".") -> List[Document]:
"""Retrieve all files from the file repository.
:param relative_path: The relative path within the file repository.
@ -224,11 +221,12 @@ class FileRepository:
:return: A list of documents representing all files in the repository.
:rtype: List[Document]
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get_all()
@staticmethod
async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."):
async def save_file(
self, filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."
):
"""Save a file to the file repository.
:param filename: The name or path of the file to save.
@ -239,12 +237,11 @@ class FileRepository:
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save(filename=filename, content=content, dependencies=dependencies)
@staticmethod
async def save_as(
doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
self, doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
):
"""Save a Document instance with optional modifications.
@ -262,7 +259,7 @@ class FileRepository:
:return: A boolean indicating whether the save operation was successful.
:rtype: bool
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies)
async def delete(self, filename: Path | str):
@ -282,7 +279,6 @@ class FileRepository:
await dependency_file.update(filename=pathname, dependencies=None)
logger.info(f"remove dependency key: {str(pathname)}")
@staticmethod
async def delete_file(filename: Path | str, relative_path: Path | str = "."):
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
async def delete_file(self, filename: Path | str, relative_path: Path | str = "."):
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
await file_repo.delete(filename=filename)

View file

@ -15,8 +15,9 @@ from typing import Optional
import pytest
from metagpt.config import CONFIG, Config
from metagpt.config2 import config
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
from metagpt.context import context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM
@ -53,6 +54,7 @@ class MockLLM(OpenAILLM):
timeout=3,
stream=True,
) -> str:
logger.debug(f"MockLLM.aask: {msg}")
if msg not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
@ -81,7 +83,8 @@ def rsp_cache():
@pytest.fixture(scope="function")
def llm_mock(rsp_cache, mocker):
llm = MockLLM()
llm = MockLLM(config.get_llm_config())
llm.cost_manager = context.cost_manager
llm.rsp_cache = rsp_cache
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
yield mocker
@ -90,7 +93,7 @@ def llm_mock(rsp_cache, mocker):
class Context:
def __init__(self):
self._llm_ui = None
self._llm_api = LLM(provider=CONFIG.get_default_llm_provider_enum())
self._llm_api = LLM()
@property
def llm_api(self):
@ -153,12 +156,12 @@ def loguru_caplog(caplog):
# init & dispose git repo
@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown_git_repo(request):
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
CONFIG.git_reinit = True
context.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
context.config.git_reinit = True
# Destroy git repo at the end of the test session.
def fin():
CONFIG.git_repo.delete_repository()
context.git_repo.delete_repository()
# Register the function for destroying the environment.
request.addfinalizer(fin)
@ -166,4 +169,4 @@ def setup_and_teardown_git_repo(request):
@pytest.fixture(scope="session", autouse=True)
def init_config():
Config()
pass

View file

@ -25,6 +25,6 @@ async def test_prepare_documents():
await PrepareDocuments(g_context=context).run(with_messages=[msg])
assert context.git_repo
doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
doc = await FileRepository(context.git_repo).get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
assert doc
assert doc.content == msg.content

View file

@ -7,6 +7,7 @@
"""
import pytest
from metagpt.context import context
from metagpt.logs import logger
from metagpt.roles import ProductManager
from tests.metagpt.roles.mock import MockMessages
@ -15,7 +16,7 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_product_manager():
product_manager = ProductManager()
product_manager = ProductManager(context=context)
rsp = await product_manager.run(MockMessages.req)
logger.info(rsp)
assert len(rsp.content) > 0