mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
fix config loader
This commit is contained in:
parent
e2cf8bb271
commit
95dbc87158
3 changed files with 48 additions and 28 deletions
7
config.yaml
Normal file
7
config.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Default options, loaded with yaml.safe_load and merged with user overrides.
model: gpt-4o-2024-11-20
toc_check_page_num: 20
max_page_num_each_node: 10
max_token_num_each_node: 20000
# The yes/no flags are quoted on purpose: a YAML 1.1 loader resolves a bare
# yes/no to a boolean, while the previous in-code defaults were the strings
# "yes" and "no".
if_add_node_id: "yes"
if_add_node_summary: "no"
if_add_doc_description: "yes"
|
||||||
|
|
@ -4,7 +4,7 @@ import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from .utils import *
|
from utils import *
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -1012,8 +1012,6 @@ def tree_parser(page_list, opt, logger=None):
|
||||||
|
|
||||||
|
|
||||||
def page_index_main(doc, opt=None):
|
def page_index_main(doc, opt=None):
|
||||||
opt = merge_config(opt, get_default_opt())
|
|
||||||
|
|
||||||
logger = JsonLogger(doc)
|
logger = JsonLogger(doc)
|
||||||
|
|
||||||
is_valid_pdf = (
|
is_valid_pdf = (
|
||||||
|
|
@ -1048,6 +1046,16 @@ def page_index_main(doc, opt=None):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
               if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None):
    """Run the page-index pipeline on ``doc`` with optional per-call overrides.

    Every keyword left as ``None`` is omitted from the override mapping, so
    the value from the default configuration loaded by ``ConfigLoader`` is
    used for it instead.

    Args:
        doc: The document to process; passed through unchanged to
            ``page_index_main``.
        model, toc_check_page_num, max_page_num_each_node,
        max_token_num_each_node, if_add_node_id, if_add_node_summary,
        if_add_doc_description: Optional overrides. Each name must match a
            key of the default configuration.

    Returns:
        Whatever ``page_index_main(doc, opt)`` returns.

    Raises:
        ValueError: Propagated from ``ConfigLoader.load`` when an override
            key is not present in the default configuration.
    """
    # The mapping is spelled out explicitly rather than derived from
    # ``locals()`` so that each key is guaranteed to match a configuration
    # key, independent of how local variables are named.
    candidate_opt = {
        "model": model,
        "toc_check_page_num": toc_check_page_num,
        "max_page_num_each_node": max_page_num_each_node,
        "max_token_num_each_node": max_token_num_each_node,
        "if_add_node_id": if_add_node_id,
        "if_add_node_summary": if_add_node_summary,
        "if_add_doc_description": if_add_doc_description,
    }
    # Only explicitly supplied options override the defaults.
    user_opt = {key: value for key, value in candidate_opt.items() if value is not None}

    opt = ConfigLoader().load(user_opt)
    return page_index_main(doc, opt)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Set up argument parser
|
# Set up argument parser
|
||||||
|
|
|
||||||
55
utils.py
55
utils.py
|
|
@ -13,6 +13,8 @@ from io import BytesIO
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
import logging
|
import logging
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace as config
|
from types import SimpleNamespace as config
|
||||||
|
|
||||||
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
|
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
|
||||||
|
|
@ -589,32 +591,35 @@ def generate_doc_description(structure, model=None):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def get_default_opt():
|
class ConfigLoader:
    """Load default options from a YAML file and merge user overrides on top.

    The defaults are read once, at construction time. ``load`` then validates
    the user-supplied options against the default keys and returns the merged
    result as a ``config`` (``SimpleNamespace``) object.
    """

    def __init__(self, default_path: "str | Path | None" = None):
        """Read the default configuration.

        Args:
            default_path: Path of the YAML file holding the defaults. When
                ``None``, ``config.yaml`` next to this module is used.

        Raises:
            ValueError: If the YAML document is not a mapping at top level.
        """
        if default_path is None:
            default_path = Path(__file__).parent / "config.yaml"

        defaults = self._load_yaml(default_path)
        # A list or scalar at top level would otherwise only fail later,
        # during key validation or merging, with an unrelated-looking error.
        if not isinstance(defaults, dict):
            raise ValueError(
                f"Config file {default_path} must contain a top-level mapping, "
                f"got {type(defaults).__name__}"
            )
        self._default_dict = defaults

    @staticmethod
    def _load_yaml(path):
        """Parse the YAML file at ``path``; an empty document yields ``{}``."""
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _validate_keys(self, user_dict):
        """Raise ``ValueError`` if ``user_dict`` has keys absent from the defaults."""
        unknown_keys = set(user_dict) - set(self._default_dict)
        if unknown_keys:
            raise ValueError(f"Unknown config keys: {unknown_keys}")

    def load(self, user_opt=None) -> config:
        """Load the configuration, merging user options with default values.

        Args:
            user_opt: ``None`` (defaults only), a ``dict`` of overrides, or a
                ``config`` (``SimpleNamespace``) whose attributes are the
                overrides.

        Returns:
            A ``config`` object holding the defaults updated with the
            user-supplied values.

        Raises:
            TypeError: If ``user_opt`` is none of the accepted types.
            ValueError: If ``user_opt`` contains a key that is not a default key.
        """
        if user_opt is None:
            user_dict = {}
        elif isinstance(user_opt, config):
            user_dict = vars(user_opt)
        elif isinstance(user_opt, dict):
            user_dict = user_opt
        else:
            raise TypeError("user_opt must be dict, config(SimpleNamespace) or None")

        self._validate_keys(user_dict)
        # User values take precedence over the defaults.
        merged = {**self._default_dict, **user_dict}
        return config(**merged)
|
||||||
return config(**merged)
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue