refine utils code

This commit is contained in:
geekan 2023-12-19 14:17:54 +08:00 committed by better629
parent f71753ba0d
commit 8f64925290
3 changed files with 42 additions and 26 deletions

View file

@ -18,7 +18,7 @@ from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.environment import Environment
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, get_class_name
from metagpt.utils.common import any_to_str
class MockAction(Action):
@ -88,13 +88,13 @@ async def test_react():
@pytest.mark.asyncio
async def test_msg_to():
m = Message(content="a", send_to=["a", MockRole, Message])
assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)})
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message})
assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)})
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
m = Message(content="a", send_to=("a", MockRole, Message))
assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)})
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
if __name__ == "__main__":

View file

@ -16,8 +16,7 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.write_code import WriteCode
from metagpt.utils.serialize import serialize_general_message, deserialize_general_message
from metagpt.utils.common import get_class_name
from metagpt.utils.common import any_to_str
@pytest.mark.asyncio
@ -58,9 +57,9 @@ def test_message():
m.cause_by = "Message"
assert m.cause_by == "Message"
m.cause_by = Action
assert m.cause_by == get_class_name(Action)
assert m.cause_by == any_to_str(Action)
m.cause_by = Action()
assert m.cause_by == get_class_name(Action)
assert m.cause_by == any_to_str(Action)
m.content = "b"
assert m.content == "b"
@ -71,7 +70,7 @@ def test_routes():
m.send_to = "b"
assert m.send_to == {"b"}
m.send_to = {"e", Action}
assert m.send_to == {"e", get_class_name(Action)}
assert m.send_to == {"e", any_to_str(Action)}
def test_message_serdeser():