Merge pull request #1433 from iorisa/fixbug/1422

fixbug: #1422 Role's validate_role_extra resets watch in Environment obj deserialization
This commit is contained in:
better629 2024-08-07 17:26:32 +08:00 committed by GitHub
commit d73cc94356
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 26 additions and 10 deletions

View file

@ -170,7 +170,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self.llm.cost_manager = self.context.cost_manager
self._watch(kwargs.pop("watch", [UserRequirement]))
if not self.rc.watch:
self._watch(kwargs.pop("watch", [UserRequirement]))
if self.latest_observed_msg:
self.recovered = True

View file

@ -21,12 +21,12 @@ def make_sk_kernel():
if llm := config.get_azure_llm():
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
AzureChatCompletion(deployment_name=llm.model, base_url=llm.base_url, api_key=llm.api_key),
)
elif llm := config.get_openai_llm():
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(llm.model, llm.api_key),
OpenAIChatCompletion(ai_model_id=llm.model, api_key=llm.api_key),
)
return kernel

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
@ -55,6 +55,7 @@ def test_environment_serdeser(context):
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
def test_environment_serdeser_v2(context):
@ -69,6 +70,7 @@ def test_environment_serdeser_v2(context):
assert isinstance(role, ProjectManager)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
assert list(new_env.roles.values())[0].rc.watch == pm.rc.watch
def test_environment_serdeser_save(context):
@ -85,3 +87,8 @@ def test_environment_serdeser_save(context):
new_env: Environment = Environment(**env_dict, context=context)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -28,9 +28,9 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_roles(context):
role_a = RoleA()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
role_b = RoleB()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])

View file

@ -8,9 +8,9 @@ from typing import Optional
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.fix_bug import FixBug
from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
@ -68,7 +68,7 @@ class RoleA(Role):
def __init__(self, **kwargs):
super(RoleA, self).__init__(**kwargs)
self.set_actions([ActionPass])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
class RoleB(Role):
@ -93,7 +93,7 @@ class RoleC(Role):
def __init__(self, **kwargs):
super(RoleC, self).__init__(**kwargs)
self.set_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True

View file

@ -15,3 +15,7 @@ async def test_sk_agent_serdeser():
new_role = SkAgent(**ser_role_dict)
assert new_role.name == "Sunshine"
assert len(new_role.actions) == 1
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -29,3 +29,7 @@ def div(a: int, b: int = 0):
assert new_action.name == "WriteCodeReview"
await new_action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])