diff --git a/examples/aflow/optimize.py b/examples/aflow/optimize.py
index 62df68585..d07eab993 100644
--- a/examples/aflow/optimize.py
+++ b/examples/aflow/optimize.py
@@ -72,7 +72,12 @@ def parse_args():
     parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
     parser.add_argument("--check_convergence", type=bool, default=True, help="Whether to enable early stop")
     parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
-    parser.add_argument("--if_first_optimize", type=bool, default=True, help="Whether it's the first optimization")
+    parser.add_argument(
+        "--if_first_optimize",
+        type=lambda x: x.lower() == "true",
+        default=True,
+        help="Whether to download dataset for the first time",
+    )
 
     return parser.parse_args()
 
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
index 2df161ed8..2a09e0820 100644
--- a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
@@ -18,7 +18,11 @@ class DataUtils:
     def load_results(self, path: str) -> list:
         result_path = os.path.join(path, "results.json")
         if os.path.exists(result_path):
-            return read_json_file(result_path, encoding="utf-8")
+            with open(result_path, "r", encoding="utf-8") as json_file:
+                try:
+                    return json.load(json_file)
+                except json.JSONDecodeError:
+                    return []
         return []
 
     def get_top_rounds(self, sample: int, path=None, mode="Graph"):
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index c922f2cb4..3b9533571 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -10,6 +10,7 @@ ref3: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/t
 ref4: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
 ref5: https://ai.google.dev/models/gemini
 """
+import anthropic
 import tiktoken
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletionChunk
@@ -377,6 +378,10 @@ SPARK_TOKENS = {
 
 def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
     """Return the number of tokens used by a list of messages."""
+    if "claude" in model:
+        vo = anthropic.Client()
+        num_tokens = vo.count_tokens(str(messages))
+        return num_tokens
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
@@ -463,6 +468,10 @@ def count_output_tokens(string: str, model: str) -> int:
     Returns:
         int: The number of tokens in the text string.
     """
+    if "claude" in model:
+        vo = anthropic.Client()
+        num_tokens = vo.count_tokens(string)
+        return num_tokens
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError: