fix config and logger

This commit is contained in:
Ray 2025-04-06 14:49:12 +08:00
parent 9daa4101d1
commit e2cf8bb271
3 changed files with 53 additions and 18 deletions

View file

@ -0,0 +1 @@
from .page_index import *

View file

@ -2,13 +2,10 @@ import os
import json import json
import copy import copy
import math import math
import sys
import random import random
sys.path.append('../..')
import re import re
from utils import * from .utils import *
import os import os
from types import SimpleNamespace as config
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse import argparse
@ -1015,6 +1012,8 @@ def tree_parser(page_list, opt, logger=None):
def page_index_main(doc, opt=None): def page_index_main(doc, opt=None):
opt = merge_config(opt, get_default_opt())
logger = JsonLogger(doc) logger = JsonLogger(doc)
is_valid_pdf = ( is_valid_pdf = (
@ -1039,12 +1038,12 @@ def page_index_main(doc, opt=None):
if opt.if_add_doc_description == 'yes': if opt.if_add_doc_description == 'yes':
doc_description = generate_doc_description(structure, model=opt.model) doc_description = generate_doc_description(structure, model=opt.model)
return { return {
'doc_name': os.path.basename(doc), 'doc_name': get_pdf_name(doc),
'doc_description': doc_description, 'doc_description': doc_description,
'structure': structure, 'structure': structure,
} }
return { return {
'doc_name': os.path.basename(doc), 'doc_name': get_pdf_name(doc),
'structure': structure, 'structure': structure,
} }

View file

@ -13,6 +13,7 @@ from io import BytesIO
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
import logging import logging
from types import SimpleNamespace as config
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
@ -284,24 +285,27 @@ def get_last_start_page_from_text(text):
return start_page return start_page
def sanitize_filename(filename, replacement='-'): def sanitize_filename(filename, replacement='-'):
# In Linux, only '/' and '\0' (null) are invalid in filenames. # In Linux, only '/' and '\0' (null) are invalid in filenames.
# Null can't be represented in strings, so we only handle '/'. # Null can't be represented in strings, so we only handle '/'.
return filename.replace('/', replacement) return filename.replace('/', replacement)
def get_pdf_name(pdf_path):
# Extract PDF name
if isinstance(pdf_path, str):
pdf_name = os.path.basename(pdf_path)
elif isinstance(pdf_path, BytesIO):
pdf_reader = PyPDF2.PdfReader(pdf_path)
meta = pdf_reader.metadata
pdf_name = meta.title if meta.title else 'Untitled'
pdf_name = sanitize_filename(pdf_name)
return pdf_name
class JsonLogger: class JsonLogger:
def __init__(self, file_path): def __init__(self, file_path):
# Extract PDF name without extension for logger name and filename # Extract PDF name for logger name
# pdf_name = os.path.splitext(os.path.basename(file_path))[0] pdf_name = get_pdf_name(file_path)
if isinstance(file_path, str):
pdf_name = os.path.splitext(os.path.basename(file_path))[0]
elif isinstance(file_path, BytesIO):
pdf_reader = PyPDF2.PdfReader(file_path)
meta = pdf_reader.metadata
pdf_name = meta.title if meta.title else 'Untitled'
pdf_name = sanitize_filename(pdf_name)
current_time = datetime.now().strftime("%Y%m%d_%H%M%S") current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
self.filename = f"{pdf_name}_{current_time}.json" self.filename = f"{pdf_name}_{current_time}.json"
@ -582,4 +586,35 @@ def generate_doc_description(structure, model=None):
Directly return the description, do not include any other text. Directly return the description, do not include any other text.
""" """
response = ChatGPT_API(model, prompt) response = ChatGPT_API(model, prompt)
return response return response
def get_default_opt():
return {
'model': 'gpt-4o-2024-11-20',
'toc_check_page_num': 20,
'max_page_num_each_node': 10,
'max_token_num_each_node': 20000,
'if_add_node_id': 'yes',
'if_add_node_summary': 'no',
'if_add_doc_description': 'yes'
}
def validate_config_keys(user_opt_dict, default_keys):
unknown_keys = set(user_opt_dict) - set(default_keys)
if unknown_keys:
raise ValueError(f"Unknown config keys: {unknown_keys}")
def merge_config(user_opt, default_opt):
if isinstance(user_opt, config):
user_opt = vars(user_opt)
elif user_opt is None:
user_opt = {}
elif not isinstance(user_opt, dict):
raise TypeError("opt must be dict, SimpleNamespace or None")
validate_config_keys(user_opt, default_opt)
merged = {**default_opt, **user_opt}
return config(**merged)