add context and config2

This commit is contained in:
geekan 2024-01-04 22:02:47 +08:00
parent e5d11a046c
commit 10436172ca
25 changed files with 72 additions and 113 deletions

View file

@ -34,31 +34,31 @@ class Action(SerializationMixin, is_polymorphic_base=True):
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
_context: Optional[Context] = Field(default=None, exclude=True)
g_context: Optional[Context] = Field(default=None, exclude=True)
@property
def git_repo(self):
return self._context.git_repo
return self.g_context.git_repo
@property
def src_workspace(self):
return self._context.src_workspace
return self.g_context.src_workspace
@property
def prompt_schema(self):
return self._context.config.prompt_schema
return self.g_context.config.prompt_schema
@property
def project_name(self):
return self._context.config.project_name
return self.g_context.config.project_name
@project_name.setter
def project_name(self, value):
self._context.config.project_name = value
self.g_context.config.project_name = value
@property
def project_path(self):
return self._context.config.project_path
return self.g_context.config.project_path
@model_validator(mode="before")
@classmethod

View file

@ -261,7 +261,7 @@ class ActionNode:
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=None,
timeout=3,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
@ -293,7 +293,7 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, timeout=None, exclude=None):
async def simple_fill(self, schema, mode, timeout=3, exclude=None):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
@ -308,7 +308,7 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=None, exclude=[]):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.

View file

@ -51,7 +51,7 @@ Now you should start rewriting the code:
class DebugError(Action):
context: RunCodeContext = Field(default_factory=RunCodeContext)
_context: Optional[Context] = None
g_context: Optional[Context] = None
async def run(self, *args, **kwargs) -> str:
output_doc = await FileRepository.get_file(
@ -67,7 +67,7 @@ class DebugError(Action):
logger.info(f"Debug and rewrite {self.context.test_filename}")
code_doc = await FileRepository.get_file(
filename=self.context.code_filename, relative_path=self._context.src_workspace
filename=self.context.code_filename, relative_path=self.g_context.src_workspace
)
if not code_doc:
return ""

View file

@ -26,7 +26,7 @@ class PrepareDocuments(Action):
@property
def config(self):
return self._context.config
return self.g_context.config
def _init_repo(self):
"""Initialize the Git environment."""
@ -39,7 +39,7 @@ class PrepareDocuments(Action):
shutil.rmtree(path)
self.config.project_path = path
self.config.project_name = path.name
self._context.git_repo = GitRepository(local_path=path, auto_init=True)
self.g_context.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""

View file

@ -41,7 +41,7 @@ class WriteTasks(Action):
@property
def prompt_schema(self):
return self._context.config.prompt_schema
return self.g_context.config.prompt_schema
async def run(self, with_messages, schema=None):
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)

View file

@ -93,7 +93,7 @@ class RunCode(Action):
additional_python_paths = [str(path) for path in additional_python_paths]
# Copy the current environment variables
env = self._context.new_environ()
env = self.g_context.new_environ()
# Modify the PYTHONPATH environment variable
additional_python_paths = [working_directory] + additional_python_paths

View file

@ -104,7 +104,7 @@ class SummarizeCode(Action):
design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
task_pathname = Path(self.context.task_filename)
task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
src_file_repo = self.git_repo.new_file_repository(relative_path=self._context.src_workspace)
src_file_repo = self.git_repo.new_file_repository(relative_path=self.g_context.src_workspace)
code_blocks = []
for filename in self.context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -117,7 +117,7 @@ class WriteCode(Action):
coding_context.task_doc,
exclude=self.context.filename,
git_repo=self.git_repo,
src_workspace=self._context.src_workspace,
src_workspace=self.g_context.src_workspace,
)
prompt = PROMPT_TEMPLATE.format(
@ -133,7 +133,7 @@ class WriteCode(Action):
code = await self.write_code(prompt)
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = self._context.src_workspace if self._context.src_workspace else ""
root_path = self.g_context.src_workspace if self.g_context.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
coding_context.code_doc.content = code
return coding_context

View file

@ -136,14 +136,14 @@ class WriteCodeReview(Action):
async def run(self, *args, **kwargs) -> CodingContext:
iterative_code = self.context.code_doc.content
k = self._context.config.code_review_k_times or 1
k = self.g_context.config.code_review_k_times or 1
for i in range(k):
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
task_content = self.context.task_doc.content if self.context.task_doc else ""
code_context = await WriteCode.get_codes(
self.context.task_doc,
exclude=self.context.filename,
git_repo=self._context.git_repo,
git_repo=self.g_context.git_repo,
src_workspace=self.src_workspace,
)
context = "\n".join(

View file

@ -9,8 +9,6 @@ import os
from pathlib import Path
from typing import Dict, Optional
from pydantic import BaseModel
from metagpt.config2 import Config
from metagpt.const import OPTIONS
from metagpt.provider.base_llm import BaseLLM
@ -19,7 +17,7 @@ from metagpt.utils.cost_manager import CostManager
from metagpt.utils.git_repository import GitRepository
class Context(BaseModel):
class Context:
kwargs: Dict = {}
config: Config = Config.default()
git_repo: Optional[GitRepository] = None

View file

@ -97,8 +97,10 @@ class Assistant(Role):
async def talk_handler(self, text, **kwargs) -> bool:
history = self.memory.history_text
text = kwargs.get("last_talk") or text
self.rc.todo = TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
self.set_todo(
TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
)
)
return True
@ -112,7 +114,7 @@ class Assistant(Role):
await action.run(**kwargs)
if action.args is None:
return await self.talk_handler(text=last_talk, **kwargs)
self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)
self.set_todo(SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description))
return True
async def refine_memory(self) -> str:

View file

@ -281,7 +281,9 @@ class Engineer(Role):
f"{changed_files.docs[task_filename].model_dump_json()}"
)
changed_files.docs[task_filename] = coding_doc
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
self.code_todos = [
WriteCode(context=i, g_context=self.context, llm=self.llm) for i in changed_files.docs.values()
]
# Code directly modified by the user.
dependency = await self.git_repo.get_dependency()
for filename in changed_src_files:
@ -295,10 +297,10 @@ class Engineer(Role):
dependency=dependency,
)
changed_files.docs[filename] = coding_doc
self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm))
self.code_todos.append(WriteCode(context=coding_doc, g_context=self.context, llm=self.llm))
if self.code_todos:
self.rc.todo = self.code_todos[0]
self.set_todo(self.code_todos[0])
async def _new_summarize_actions(self):
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
@ -313,7 +315,7 @@ class Engineer(Role):
ctx.codes_filenames = filenames
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm))
if self.summarize_todos:
self.rc.todo = self.summarize_todos[0]
self.set_todo(self.summarize_todos[0])
@property
def todo(self) -> str:

View file

@ -87,7 +87,7 @@ class InvoiceOCRAssistant(Role):
else:
self._init_actions([GenerateTable])
self.rc.todo = None
self.set_todo(None)
content = INVOICE_OCR_SUCCESS
resp = OCRResults(ocr_result=json.dumps(resp))
msg = Message(content=content, instruct_content=resp)

View file

@ -72,7 +72,7 @@ class QaEngineer(Role):
)
logger.info(f"Writing {test_doc.filename}..")
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
context = await WriteTest(context=context, _context=self.context, llm=self.llm).run()
context = await WriteTest(context=context, g_context=self.context, llm=self.llm).run()
await tests_file_repo.save(
filename=context.test_doc.filename,
content=context.test_doc.content,
@ -137,7 +137,7 @@ class QaEngineer(Role):
async def _debug_error(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
code = await DebugError(context=run_code_context, llm=self.llm).run()
code = await DebugError(context=run_code_context, g_context=self.context, llm=self.llm).run()
await FileRepository.save_file(
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
)

View file

@ -49,7 +49,7 @@ class Researcher(Role):
if self.rc.state + 1 < len(self.states):
self._set_state(self.rc.state + 1)
else:
self.rc.todo = None
self.set_todo(None)
return False
async def _act(self) -> Message:

View file

@ -154,6 +154,15 @@ class Role(SerializationMixin, is_polymorphic_base=True):
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
@property
def todo(self) -> Action:
return self.rc.todo
def set_todo(self, value: Optional[Action]):
if value:
value.g_context = self.context
self.rc.todo = value
@property
def config(self):
return self.context.config
@ -326,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
"""Update the current state."""
self.rc.state = state
logger.debug(f"actions={self.actions}, state={state}")
self.rc.todo = self.actions[self.rc.state] if state >= 0 else None
self.set_todo(self.actions[self.rc.state] if state >= 0 else None)
def set_env(self, env: "Environment"):
"""Set the environment in which the role works. The role can talk to the environment and can also receive
@ -521,7 +530,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
rsp = await self.react()
# Reset the next action to be taken.
self.rc.todo = None
self.set_todo(None)
# Send the response message to the Environment object to have it relay the message to the subscribers.
self.publish_message(rsp)
return rsp
@ -542,8 +551,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
return ActionOutput(content=msg.content, instruct_content=msg.instruct_content)
@property
def todo(self) -> str:
def first_action(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
# FIXME: this is a hack, we should not use the first action to represent the todo
if self.actions:
return any_to_name(self.actions[0])
return ""

View file

@ -59,7 +59,7 @@ class Teacher(Role):
self._set_state(self.rc.state + 1)
return True
self.rc.todo = None
self.set_todo(None)
return False
async def _react(self) -> Message:

View file

@ -104,9 +104,9 @@ class Context:
@pytest.fixture(scope="package")
def llm_api():
logger.info("Setting up the test")
_context = Context()
g_context = Context()
yield _context.llm_api
yield g_context.llm_api
logger.info("Tearing down the test")

View file

@ -9,8 +9,8 @@
import pytest
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.context import context
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
@ -19,12 +19,12 @@ from metagpt.utils.file_repository import FileRepository
async def test_prepare_documents():
msg = Message(content="New user requirements balabala...")
if CONFIG.git_repo:
CONFIG.git_repo.delete_repository()
CONFIG.git_repo = None
if context.git_repo:
context.git_repo.delete_repository()
context.git_repo = None
await PrepareDocuments().run(with_messages=[msg])
assert CONFIG.git_repo
await PrepareDocuments(g_context=context).run(with_messages=[msg])
assert context.git_repo
doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
assert doc
assert doc.content == msg.content

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/11 14:44
@Author : alexanderwu
@File : test_action.py
"""

View file

@ -5,7 +5,7 @@
@Author : alexanderwu
@File : test_document.py
"""
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.document import Repo
from metagpt.logs import logger
@ -28,6 +28,6 @@ def load_existing_repo(path):
def test_repo_set_load():
repo_path = CONFIG.path / "test_repo"
repo_path = config.workspace.path / "test_repo"
set_existing_repo(repo_path)
load_existing_repo(repo_path)

View file

@ -13,7 +13,7 @@ from pathlib import Path
import pytest
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
from metagpt.context import context
from metagpt.environment import Environment
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, Role
@ -46,9 +46,9 @@ def test_get_roles(env: Environment):
@pytest.mark.asyncio
async def test_publish_and_process_message(env: Environment):
if CONFIG.git_repo:
CONFIG.git_repo.delete_repository()
CONFIG.git_repo = None
if context.git_repo:
context.git_repo.delete_repository()
context.git_repo = None
product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限")
architect = Architect(

View file

@ -1,45 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/4/29 19:47
@Author : alexanderwu
@File : test_gpt.py
"""
import openai
import pytest
from metagpt.config import CONFIG
from metagpt.logs import logger
@pytest.mark.usefixtures("llm_api")
class TestGPT:
@pytest.mark.asyncio
async def test_llm_api_aask(self, llm_api):
answer = await llm_api.aask("hello chatgpt", stream=False)
logger.info(answer)
assert len(answer) > 0
answer = await llm_api.aask("hello chatgpt", stream=True)
logger.info(answer)
assert len(answer) > 0
@pytest.mark.asyncio
async def test_llm_api_aask_code(self, llm_api):
try:
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"], timeout=60)
logger.info(answer)
assert len(answer) > 0
except openai.BadRequestError:
assert CONFIG.OPENAI_API_TYPE == "azure"
@pytest.mark.asyncio
async def test_llm_api_costs(self, llm_api):
await llm_api.aask("hello chatgpt", stream=False)
costs = llm_api.get_costs()
logger.info(costs)
assert costs.total_cost > 0
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -9,7 +9,7 @@
import pytest
from metagpt.provider.openai_api import OpenAILLM as LLM
from metagpt.llm import LLM
@pytest.fixture()
@ -23,6 +23,12 @@ async def test_llm_aask(llm):
assert len(rsp) > 0
@pytest.mark.asyncio
async def test_llm_aask_stream(llm):
rsp = await llm.aask("hello world", stream=True)
assert len(rsp) > 0
@pytest.mark.asyncio
async def test_llm_acompletion(llm):
hello_msg = [{"role": "user", "content": "hello"}]

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/11 14:45
@Author : alexanderwu
@File : test_manager.py
"""