Merge pull request #1552 from didiforgithub/main

Fix JSON load error & fix Claude token calculation error.
This commit is contained in:
Alexander Wu 2024-10-30 13:35:14 +08:00 committed by GitHub
commit fd7feb57fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 20 additions and 2 deletions

View file

@@ -72,7 +72,12 @@ def parse_args():
parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
parser.add_argument("--check_convergence", type=bool, default=True, help="Whether to enable early stop")
parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
parser.add_argument("--if_first_optimize", type=bool, default=True, help="Whether it's the first optimization")
parser.add_argument(
"--if_first_optimize",
type=lambda x: x.lower() == "true",
default=True,
help="Whether to download dataset for the first time",
)
return parser.parse_args()

View file

@@ -18,7 +18,11 @@ class DataUtils:
def load_results(self, path: str) -> list:
    """Load previously saved optimization results from *path*.

    Args:
        path: Directory expected to contain a ``results.json`` file.

    Returns:
        list: The parsed JSON content, or ``[]`` when the file is
        missing or contains invalid/empty JSON.
    """
    result_path = os.path.join(path, "results.json")
    if os.path.exists(result_path):
        # Explicit UTF-8: don't depend on the platform's default locale
        # encoding (the previous read_json_file call used utf-8 too).
        with open(result_path, "r", encoding="utf-8") as json_file:
            try:
                return json.load(json_file)
            except json.JSONDecodeError:
                # A partially written / corrupted results file is treated
                # as "no results yet" rather than crashing the run.
                return []
    return []
def get_top_rounds(self, sample: int, path=None, mode="Graph"):

View file

@ -10,6 +10,7 @@ ref3: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/t
ref4: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
ref5: https://ai.google.dev/models/gemini
"""
import anthropic
import tiktoken
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
@@ -377,6 +378,10 @@ SPARK_TOKENS = {
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""
if "claude" in model:
vo = anthropic.Client()
num_tokens = vo.count_tokens(str(messages))
return num_tokens
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@@ -463,6 +468,10 @@ def count_output_tokens(string: str, model: str) -> int:
Returns:
int: The number of tokens in the text string.
"""
if "claude" in model:
vo = anthropic.Client()
num_tokens = vo.count_tokens(string)
return num_tokens
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError: