fixbug: Message id, token counter

This commit is contained in:
莘权 马 2023-12-19 10:44:06 +08:00
parent 5593008110
commit e42b1969cc
4 changed files with 16 additions and 9 deletions

View file

@ -163,8 +163,14 @@ class Message(BaseModel):
def load(val):
"""Convert the json string to object."""
try:
d = json.loads(val)
return Message(**d)
m = json.loads(val)
id = m.get("id")
if "id" in m:
del m["id"]
msg = Message(**m)
if id:
msg.id = id
return msg
except JSONDecodeError as err:
logger.error(f"parse json failed: {val}, error:{err}")
return None

View file

@ -88,13 +88,13 @@ async def test_react():
@pytest.mark.asyncio
async def test_msg_to():
m = Message(content="a", send_to=["a", MockRole, Message])
assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)})
assert m.send_to == {"a", get_class_name(MockRole), get_class_name(Message)}
m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message})
assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)})
assert m.send_to == {"a", get_class_name(MockRole), get_class_name(Message)}
m = Message(content="a", send_to=("a", MockRole, Message))
assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)})
assert m.send_to == {"a", get_class_name(MockRole), get_class_name(Message)}
if __name__ == "__main__":

View file

@ -16,7 +16,6 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
from metagpt.utils.common import get_class_name
@pytest.mark.asyncio
def test_messages():
test_content = "test_message"
msgs = [
@ -30,7 +29,6 @@ def test_messages():
assert all([i in text for i in roles])
@pytest.mark.asyncio
def test_message():
m = Message("a", role="v1")
v = m.dump()
@ -61,7 +59,6 @@ def test_message():
assert m.content == "b"
@pytest.mark.asyncio
def test_routes():
m = Message("a", role="b", cause_by="c", x="d", send_to="c")
m.send_to = "b"

View file

@ -15,7 +15,7 @@ def test_count_message_tokens():
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
assert count_message_tokens(messages) == 17
assert count_message_tokens(messages) == 15
def test_count_message_tokens_with_name():
@ -67,3 +67,7 @@ def test_count_string_tokens_gpt_4():
string = "Hello, world!"
assert count_string_tokens(string, model_name="gpt-4-0314") == 4
if __name__ == "__main__":
pytest.main([__file__, "-s"])