diff --git a/metagpt/schema.py b/metagpt/schema.py index 758149efa..9916bffff 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -163,8 +163,14 @@ class Message(BaseModel): def load(val): """Convert the json string to object.""" try: - d = json.loads(val) - return Message(**d) + m = json.loads(val) + id = m.get("id") + if "id" in m: + del m["id"] + msg = Message(**m) + if id: + msg.id = id + return msg except JSONDecodeError as err: logger.error(f"parse json failed: {val}, error:{err}") return None diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 8fac2503c..cf09d6f0a 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -88,13 +88,13 @@ async def test_react(): @pytest.mark.asyncio async def test_msg_to(): m = Message(content="a", send_to=["a", MockRole, Message]) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", get_class_name(MockRole), get_class_name(Message)} m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message}) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", get_class_name(MockRole), get_class_name(Message)} m = Message(content="a", send_to=("a", MockRole, Message)) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", get_class_name(MockRole), get_class_name(Message)} if __name__ == "__main__": diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 51ebd5baa..40b18e0f4 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -16,7 +16,6 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.utils.common import get_class_name -@pytest.mark.asyncio def test_messages(): test_content = "test_message" msgs = [ @@ -30,7 +29,6 @@ def test_messages(): assert all([i in text for i in roles]) -@pytest.mark.asyncio def test_message(): m = Message("a", role="v1") v = m.dump() @@ -61,7 +59,6 @@ def test_message(): assert m.content == "b" -@pytest.mark.asyncio def test_routes(): m = Message("a", role="b", cause_by="c", x="d", send_to="c") m.send_to = "b" diff --git a/tests/metagpt/utils/test_token_counter.py b/tests/metagpt/utils/test_token_counter.py index 479ccc22d..acb99d717 100644 --- a/tests/metagpt/utils/test_token_counter.py +++ b/tests/metagpt/utils/test_token_counter.py @@ -15,7 +15,7 @@ def test_count_message_tokens(): {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] - assert count_message_tokens(messages) == 17 + assert count_message_tokens(messages) == 15 def test_count_message_tokens_with_name(): @@ -67,3 +67,7 @@ def test_count_string_tokens_gpt_4(): string = "Hello, world!" assert count_string_tokens(string, model_name="gpt-4-0314") == 4 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])