fix config loader

This commit is contained in:
Ray 2025-04-06 19:11:45 +08:00
parent e2cf8bb271
commit 95dbc87158
3 changed files with 48 additions and 28 deletions

7
config.yaml Normal file
View file

@ -0,0 +1,7 @@
# Default options; values supplied by the caller override these.
model: gpt-4o-2024-11-20
toc_check_page_num: 20
max_page_num_each_node: 10
max_token_num_each_node: 20000
# The yes/no flags are quoted on purpose: an unquoted yes/no is parsed as a
# boolean by YAML 1.1 loaders, while the previous in-code defaults were the
# strings 'yes' / 'no'.
if_add_node_id: "yes"
if_add_node_summary: "no"
if_add_doc_description: "yes"

View file

@ -4,7 +4,7 @@ import copy
import math import math
import random import random
import re import re
from .utils import * from utils import *
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse import argparse
@ -1012,8 +1012,6 @@ def tree_parser(page_list, opt, logger=None):
def page_index_main(doc, opt=None): def page_index_main(doc, opt=None):
opt = merge_config(opt, get_default_opt())
logger = JsonLogger(doc) logger = JsonLogger(doc)
is_valid_pdf = ( is_valid_pdf = (
@ -1048,6 +1046,16 @@ def page_index_main(doc, opt=None):
} }
def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
               if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None):
    """Run page_index_main on *doc* with keyword overrides applied to the defaults.

    Every option left as None is omitted, so the default loaded by
    ConfigLoader is used for it. The remaining options are passed to
    ConfigLoader().load(), and the resulting config object is handed to
    page_index_main.

    Args:
        doc: The document to process; forwarded unchanged to page_index_main.
        model, toc_check_page_num, max_page_num_each_node,
        max_token_num_each_node, if_add_node_id, if_add_node_summary,
        if_add_doc_description: Optional overrides for the configuration
            keys of the same name.

    Returns:
        Whatever page_index_main(doc, opt) returns.
    """
    # Spell the mapping out instead of harvesting locals(): the keys must
    # match the configuration keys exactly, and an explicit dict cannot pick
    # up unrelated local variables if this function grows.
    overrides = {
        "model": model,
        "toc_check_page_num": toc_check_page_num,
        "max_page_num_each_node": max_page_num_each_node,
        "max_token_num_each_node": max_token_num_each_node,
        "if_add_node_id": if_add_node_id,
        "if_add_node_summary": if_add_node_summary,
        "if_add_doc_description": if_add_doc_description,
    }
    user_opt = {key: value for key, value in overrides.items() if value is not None}
    opt = ConfigLoader().load(user_opt)
    return page_index_main(doc, opt)
if __name__ == "__main__": if __name__ == "__main__":
# Set up argument parser # Set up argument parser

View file

@ -13,6 +13,8 @@ from io import BytesIO
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
import logging import logging
import yaml
from pathlib import Path
from types import SimpleNamespace as config from types import SimpleNamespace as config
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
@ -589,32 +591,35 @@ def generate_doc_description(structure, model=None):
return response return response
def get_default_opt(): class ConfigLoader:
return { def __init__(self, default_path: str = None):
'model': 'gpt-4o-2024-11-20', if default_path is None:
'toc_check_page_num': 20, default_path = Path(__file__).parent / "config.yaml"
'max_page_num_each_node': 10, self._default_dict = self._load_yaml(default_path)
'max_token_num_each_node': 20000,
'if_add_node_id': 'yes',
'if_add_node_summary': 'no',
'if_add_doc_description': 'yes'
}
def validate_config_keys(user_opt_dict, default_keys): @staticmethod
unknown_keys = set(user_opt_dict) - set(default_keys) def _load_yaml(path):
if unknown_keys: with open(path, "r", encoding="utf-8") as f:
raise ValueError(f"Unknown config keys: {unknown_keys}") return yaml.safe_load(f) or {}
def merge_config(user_opt, default_opt): def _validate_keys(self, user_dict):
unknown_keys = set(user_dict) - set(self._default_dict)
if unknown_keys:
raise ValueError(f"Unknown config keys: {unknown_keys}")
if isinstance(user_opt, config): def load(self, user_opt=None) -> config:
user_opt = vars(user_opt) """
elif user_opt is None: Load the configuration, merging user options with default values.
user_opt = {} """
elif not isinstance(user_opt, dict): if user_opt is None:
raise TypeError("opt must be dict, SimpleNamespace or None") user_dict = {}
elif isinstance(user_opt, config):
user_dict = vars(user_opt)
elif isinstance(user_opt, dict):
user_dict = user_opt
else:
raise TypeError("user_opt must be dict, config(SimpleNamespace) or None")
validate_config_keys(user_opt, default_opt) self._validate_keys(user_dict)
merged = {**self._default_dict, **user_dict}
merged = {**default_opt, **user_opt} return config(**merged)
return config(**merged)