feat: replace global CONTEXT with Config()

fixbug: unit test
This commit is contained in:
莘权 马 2024-01-22 17:13:20 +08:00
parent ff314388bb
commit e8b3e6762b
15 changed files with 44 additions and 33 deletions

View file

@ -161,7 +161,7 @@ class WriteCodeReview(Action):
format_example=format_example,
)
len1 = len(iterative_code) if iterative_code else 0
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
logger.info(
f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
f"len(self.i_context.code_doc.content)={len2}"

View file

@ -38,6 +38,7 @@ class CLIParams(BaseModel):
if self.project_path:
self.inc = True
self.project_name = self.project_name or Path(self.project_path).name
return self
class Config(CLIParams, YamlModel):

View file

@ -95,7 +95,3 @@ class Context(BaseModel):
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
return llm
# Global context, not in Env
CONTEXT = Context()

View file

@ -10,7 +10,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.context import CONTEXT, Context
from metagpt.context import Context
from metagpt.provider.base_llm import BaseLLM
@ -34,7 +34,7 @@ class ContextMixin(BaseModel):
def __init__(
self,
context: Optional[Context] = CONTEXT,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
**kwargs,
@ -81,7 +81,7 @@ class ContextMixin(BaseModel):
"""Role context: role context > context"""
if self.private_context:
return self.private_context
return CONTEXT
return Context()
@context.setter
def context(self, context: Context) -> None:

View file

@ -13,7 +13,7 @@ import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.context import CONTEXT, Context
from metagpt.context import Context
class Example(BaseModel):
@ -73,14 +73,15 @@ class SkillsDeclaration(BaseModel):
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)
def get_skill_list(self, entity_name: str = "Assistant", context: Context = CONTEXT) -> Dict:
def get_skill_list(self, entity_name: str = "Assistant", context: Context = None) -> Dict:
"""Return the skill name based on the skill description."""
entity = self.entities.get(entity_name)
if not entity:
return {}
# List of skills that the agent chooses to activate.
agent_skills = context.kwargs.agent_skills
ctx = context or Context()
agent_skills = ctx.kwargs.agent_skills
if not agent_skills:
return {}

View file

@ -8,12 +8,13 @@
from typing import Optional
from metagpt.configs.llm_config import LLMConfig
from metagpt.context import CONTEXT
from metagpt.context import Context
from metagpt.provider.base_llm import BaseLLM
def LLM(llm_config: Optional[LLMConfig] = None) -> BaseLLM:
def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> BaseLLM:
"""get the default llm provider if name is None"""
ctx = context or Context()
if llm_config is not None:
CONTEXT.llm_with_cost_manager_from_llm_config(llm_config)
return CONTEXT.llm()
ctx.llm_with_cost_manager_from_llm_config(llm_config)
return ctx.llm()

View file

@ -22,7 +22,6 @@ from pydantic import Field
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
from metagpt.actions.talk_action import TalkAction
from metagpt.context import CONTEXT
from metagpt.learn.skill_loader import SkillsDeclaration
from metagpt.logs import logger
from metagpt.memory.brain_memory import BrainMemory
@ -48,7 +47,7 @@ class Assistant(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
language = kwargs.get("language") or self.context.kwargs.language or CONTEXT.kwargs.language
language = kwargs.get("language") or self.context.kwargs.language
self.constraints = self.constraints.format(language=language)
async def think(self) -> bool:

View file

@ -8,6 +8,7 @@ import typer
from metagpt.config2 import config
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.context import Context
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@ -37,9 +38,10 @@ def generate_repo(
from metagpt.team import Team
config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
ctx = Context(config=config)
if not recover_path:
company = Team()
company = Team(context=ctx)
company.hire(
[
ProductManager(),
@ -58,7 +60,7 @@ def generate_repo(
if not stg_path.exists() or not str(stg_path).endswith("team"):
raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`")
company = Team.deserialize(stg_path=stg_path)
company = Team.deserialize(stg_path=stg_path, context=ctx)
idea = company.idea
company.invest(investment)

View file

@ -10,12 +10,13 @@
import warnings
from pathlib import Path
from typing import Any
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field
from metagpt.actions import UserRequirement
from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH
from metagpt.context import Context
from metagpt.environment import Environment
from metagpt.logs import logger
from metagpt.roles import Role
@ -36,12 +37,17 @@ class Team(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
env: Environment = Field(default_factory=Environment)
env: Optional[Environment] = None
investment: float = Field(default=10.0)
idea: str = Field(default="")
def __init__(self, **data: Any):
def __init__(self, context: Context = None, **data: Any):
super(Team, self).__init__(**data)
ctx = context or Context()
if not self.env:
self.env = Environment(context=ctx)
else:
self.env.context = ctx # The `env` object is allocated by deserialization
if "roles" in data:
self.hire(data["roles"])
if "env_desc" in data:
@ -54,7 +60,7 @@ class Team(BaseModel):
write_json_file(team_info_path, self.model_dump())
@classmethod
def deserialize(cls, stg_path: Path) -> "Team":
def deserialize(cls, stg_path: Path, context: Context = None) -> "Team":
"""stg_path = ./storage/team"""
# recover team_info
team_info_path = stg_path.joinpath("team.json")
@ -64,7 +70,8 @@ class Team(BaseModel):
)
team_info: dict = read_json_file(team_info_path)
team = Team(**team_info)
ctx = context or Context()
team = Team(**team_info, context=ctx)
return team
def hire(self, roles: list[Role]):

BIN
tests/data/audio/hello.mp3 Normal file

Binary file not shown.

View file

@ -99,7 +99,7 @@ def test_parse_code():
def test_todo():
role = Engineer()
assert role.todo == any_to_name(WriteCode)
assert role.action_description == any_to_name(WriteCode)
@pytest.mark.asyncio

View file

@ -144,3 +144,7 @@ async def test_team_recover_multi_roles_save(mocker, context):
assert new_company.env.get_role(role_b.profile).rc.state == 1
await new_company.run(n_round=4)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,7 +6,7 @@
@File : test_context.py
"""
from metagpt.configs.llm_config import LLMType
from metagpt.context import CONTEXT, AttrDict, Context
from metagpt.context import AttrDict, Context
def test_attr_dict_1():
@ -51,11 +51,12 @@ def test_context_1():
def test_context_2():
llm = CONTEXT.config.get_openai_llm()
ctx = Context()
llm = ctx.config.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
kwargs = CONTEXT.kwargs
kwargs = ctx.kwargs
assert kwargs is not None
kwargs.test_key = "test_value"

View file

@ -11,7 +11,6 @@ from pathlib import Path
import pytest
from metagpt.actions import UserRequirement
from metagpt.context import CONTEXT
from metagpt.environment import Environment
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, Role
@ -44,9 +43,9 @@ def test_get_roles(env: Environment):
@pytest.mark.asyncio
async def test_publish_and_process_message(env: Environment):
if CONTEXT.git_repo:
CONTEXT.git_repo.delete_repository()
CONTEXT.git_repo = None
if env.context.git_repo:
env.context.git_repo.delete_repository()
env.context.git_repo = None
product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限")
architect = Architect(

View file

@ -131,7 +131,7 @@ async def test_recover():
role.recovered = True
role.latest_observed_msg = Message(content="recover_test")
role.rc.state = 0
assert role.first_action == any_to_name(MockAction)
assert role.action_description == any_to_name(MockAction)
rsp = await role.run()
assert rsp.cause_by == any_to_str(MockAction)