diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..e1d4666 --- /dev/null +++ b/config.yaml @@ -0,0 +1,7 @@ +model: gpt-4o-2024-11-20 +toc_check_page_num: 20 +max_page_num_each_node: 10 +max_token_num_each_node: 20000 +if_add_node_id: yes +if_add_node_summary: no +if_add_doc_description: yes \ No newline at end of file diff --git a/page_index.py b/page_index.py index 0e45863..9f264a0 100644 --- a/page_index.py +++ b/page_index.py @@ -4,7 +4,7 @@ import copy import math import random import re -from .utils import * +from utils import * import os from concurrent.futures import ThreadPoolExecutor, as_completed import argparse @@ -1012,8 +1012,6 @@ def tree_parser(page_list, opt, logger=None): def page_index_main(doc, opt=None): - opt = merge_config(opt, get_default_opt()) - logger = JsonLogger(doc) is_valid_pdf = ( @@ -1048,6 +1046,16 @@ def page_index_main(doc, opt=None): } +def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None, + f_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None): + + user_opt = { + arg: value for arg, value in locals().items() + if arg != "doc" and value is not None + } + opt = ConfigLoader().load(user_opt) + return page_index_main(doc, opt) + if __name__ == "__main__": # Set up argument parser diff --git a/utils.py b/utils.py index b261aba..b77348b 100644 --- a/utils.py +++ b/utils.py @@ -13,6 +13,8 @@ from io import BytesIO from dotenv import load_dotenv load_dotenv() import logging +import yaml +from pathlib import Path from types import SimpleNamespace as config CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") @@ -589,32 +591,35 @@ def generate_doc_description(structure, model=None): return response -def get_default_opt(): - return { - 'model': 'gpt-4o-2024-11-20', - 'toc_check_page_num': 20, - 'max_page_num_each_node': 10, - 'max_token_num_each_node': 20000, - 'if_add_node_id': 'yes', - 'if_add_node_summary': 'no', - 'if_add_doc_description': 'yes' - } +class ConfigLoader: + def __init__(self, default_path: str = None): + if default_path is None: + default_path = Path(__file__).parent / "config.yaml" + self._default_dict = self._load_yaml(default_path) -def validate_config_keys(user_opt_dict, default_keys): - unknown_keys = set(user_opt_dict) - set(default_keys) - if unknown_keys: - raise ValueError(f"Unknown config keys: {unknown_keys}") + @staticmethod + def _load_yaml(path): + with open(path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} -def merge_config(user_opt, default_opt): + def _validate_keys(self, user_dict): + unknown_keys = set(user_dict) - set(self._default_dict) + if unknown_keys: + raise ValueError(f"Unknown config keys: {unknown_keys}") - if isinstance(user_opt, config): - user_opt = vars(user_opt) - elif user_opt is None: - user_opt = {} - elif not isinstance(user_opt, dict): - raise TypeError("opt must be dict, SimpleNamespace or None") + def load(self, user_opt=None) -> config: + """ + Load the configuration, merging user options with default values. + """ + if user_opt is None: + user_dict = {} + elif isinstance(user_opt, config): + user_dict = vars(user_opt) + elif isinstance(user_opt, dict): + user_dict = user_opt + else: + raise TypeError("user_opt must be dict, config(SimpleNamespace) or None") - validate_config_keys(user_opt, default_opt) - - merged = {**default_opt, **user_opt} - return config(**merged) \ No newline at end of file + self._validate_keys(user_dict) + merged = {**self._default_dict, **user_dict} + return config(**merged) \ No newline at end of file