diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index d68b9e752..e977884ab 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -403,20 +403,25 @@ SPARK_TOKENS = {
 }
 
 
+def count_claude_message_tokens(messages: list[dict], model: str) -> int:
+    # Rough estimation for models newer than claude-2.1; needs api_key or auth_token.
+    # The Anthropic count_tokens endpoint takes the system prompt separately, so
+    # split any "system" message out of the OpenAI-style message list first.
+    ac = anthropic.Client()
+    system_prompt = ""
+    new_messages = []
+    for msg in messages:
+        if msg.get("role") == "system":
+            system_prompt = msg.get("content")
+        else:
+            new_messages.append(msg)
+    num_tokens = ac.beta.messages.count_tokens(messages=new_messages, model=model, system=system_prompt)
+    # Unwrap here so every caller gets a plain int, not the API response object.
+    return num_tokens.input_tokens
+
+
 def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
     """Return the number of tokens used by a list of messages."""
     if "claude" in model:
-        # rough estimation for models newer than claude-2.1, needs api_key or auth_token
-        ac = anthropic.Client()
-        system_prompt = ""
-        new_messages = []
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_prompt = msg.get("content")
-            else:
-                new_messages.append(msg)
-        num_tokens = ac.beta.messages.count_tokens(messages=new_messages, model=model, system=system_prompt)
-        return num_tokens.get("input_tokens", 0)
+        # count_claude_message_tokens already returns the int token count.
+        num_tokens = count_claude_message_tokens(messages, model)
+        return num_tokens
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
@@ -504,8 +509,8 @@ def count_output_tokens(string: str, model: str) -> int:
         int: The number of tokens in the text string.
     """
     if "claude" in model:
-        vo = anthropic.Client()
-        num_tokens = vo.count_tokens(string)
+        messages = [{"role": "assistant", "content": string}]
+        num_tokens = count_claude_message_tokens(messages, model)
         return num_tokens
     try:
         encoding = tiktoken.encoding_for_model(model)