mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
Restructure and update CLI entry point to run_pageindex.py
This commit is contained in:
parent
1668a53602
commit
403a7a4f54
7 changed files with 181 additions and 261 deletions
1
pageindex/__init__.py
Normal file
1
pageindex/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .page_index import *
|
||||
7
pageindex/config.yaml
Normal file
7
pageindex/config.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
model: gpt-4o-2024-11-20
|
||||
toc_check_page_num: 20
|
||||
max_page_num_each_node: 10
|
||||
max_token_num_each_node: 20000
|
||||
if_add_node_id: yes
|
||||
if_add_node_summary: no
|
||||
if_add_doc_description: yes
|
||||
1048
pageindex/page_index.py
Normal file
1048
pageindex/page_index.py
Normal file
File diff suppressed because it is too large
Load diff
626
pageindex/utils.py
Normal file
626
pageindex/utils.py
Normal file
|
|
@ -0,0 +1,626 @@
|
|||
import asyncio
import copy
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace as config

import openai
import PyPDF2
import pymupdf
import tiktoken
import yaml
from dotenv import load_dotenv

load_dotenv()
||||
|
||||
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
|
||||
|
||||
|
||||
def count_tokens(text, model):
    """Return the number of tokens *text* occupies under *model*'s tokenizer."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))
|
||||
|
||||
def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
    """Send *prompt* to the chat-completions API and report why generation stopped.

    Returns ``(content, status)`` where *status* is ``"max_output_reached"``
    when the model hit its output-token limit and ``"finished"`` otherwise.
    After exhausting all retries it returns the bare string ``"Error"``
    (kept as-is for backward compatibility with existing callers).

    Fix: the original appended the user turn to the caller's *chat_history*
    list in place — and re-appended it on every retry iteration, duplicating
    the prompt in the history. The history is copied now.
    """
    max_retries = 10
    client = openai.OpenAI(api_key=api_key)
    for attempt in range(max_retries):
        try:
            if chat_history:
                messages = list(chat_history) + [{"role": "user", "content": prompt}]
            else:
                messages = [{"role": "user", "content": prompt}]

            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            choice = response.choices[0]
            if choice.finish_reason == "length":
                return choice.message.content, "max_output_reached"
            return choice.message.content, "finished"
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
            if attempt < max_retries - 1:
                time.sleep(1)  # wait 1 second before retrying
            else:
                logging.error('Max retries reached for prompt: ' + prompt)
                return "Error"
|
||||
|
||||
|
||||
|
||||
def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
    """Send *prompt* to the chat-completions API and return the reply text.

    Retries up to 10 times on any error; returns the string ``"Error"`` after
    the final failed attempt (kept for backward compatibility with callers).

    Fix: the original appended the user turn to the caller's *chat_history*
    list in place — and re-appended it on every retry iteration, duplicating
    the prompt in the history. The history is copied now.
    """
    max_retries = 10
    client = openai.OpenAI(api_key=api_key)
    for attempt in range(max_retries):
        try:
            if chat_history:
                messages = list(chat_history) + [{"role": "user", "content": prompt}]
            else:
                messages = [{"role": "user", "content": prompt}]

            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
            if attempt < max_retries - 1:
                time.sleep(1)  # wait 1 second before retrying
            else:
                logging.error('Max retries reached for prompt: ' + prompt)
                return "Error"
|
||||
|
||||
|
||||
async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
    """Async single-turn chat completion.

    Retries up to 10 times on any error; returns the string "Error" after the
    final failed attempt.
    """
    max_retries = 10
    client = openai.AsyncOpenAI(api_key=api_key)
    for attempt in range(max_retries):
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
            if attempt == max_retries - 1:
                logging.error('Max retries reached for prompt: ' + prompt)
                return "Error"
            await asyncio.sleep(1)  # brief pause before the next attempt
|
||||
|
||||
def get_json_content(response):
    """Strip a Markdown ```json fence from *response* and return the inner text.

    Text without a fence is returned stripped of surrounding whitespace.
    """
    fence = response.find("```json")
    if fence != -1:
        response = response[fence + len("```json"):]

    closing = response.rfind("```")
    if closing != -1:
        response = response[:closing]

    return response.strip()
|
||||
|
||||
|
||||
def extract_json(content):
    """Best-effort extraction of a JSON object from an LLM response string.

    Handles ```json fences, Python-style ``None`` literals, stray newlines,
    and trailing commas. Returns ``{}`` when nothing parseable remains.

    Fix: the second-chance cleanup used a bare ``except:`` which swallowed
    every exception (including KeyboardInterrupt); it now catches only
    ``json.JSONDecodeError``.
    """
    try:
        # Prefer JSON enclosed within ```json ... ``` delimiters.
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # skip past the opening fence
            end_idx = content.rfind("```")
            json_content = content[start_idx:end_idx].strip()
        else:
            # No delimiters: assume the entire content could be JSON.
            json_content = content.strip()

        # Crude textual repairs for common LLM output glitches.
        # NOTE(review): the None->null rewrite also touches "None" inside
        # string values — acceptable for this pipeline, not general-purpose.
        json_content = json_content.replace('None', 'null')
        json_content = json_content.replace('\n', ' ').replace('\r', ' ')
        json_content = ' '.join(json_content.split())  # normalize whitespace

        return json.loads(json_content)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to extract JSON: {e}")
        try:
            # Remove trailing commas before closing brackets/braces and retry.
            json_content = json_content.replace(',]', ']').replace(',}', '}')
            return json.loads(json_content)
        except json.JSONDecodeError:  # was a bare except: — narrowed
            logging.error("Failed to parse JSON even after cleanup")
            return {}
    except Exception as e:
        logging.error(f"Unexpected error while extracting JSON: {e}")
        return {}
|
||||
|
||||
def write_node_id(data, node_id=0):
    """Assign zero-padded 'node_id' strings depth-first; return the next free id."""
    if isinstance(data, dict):
        data['node_id'] = f"{node_id:04d}"
        node_id += 1
        for key, value in list(data.items()):
            if 'nodes' in key:
                node_id = write_node_id(value, node_id)
    elif isinstance(data, list):
        for child in data:
            node_id = write_node_id(child, node_id)
    return node_id
|
||||
|
||||
def get_nodes(structure):
    """Flatten the tree into a list of node copies with their children removed."""
    if isinstance(structure, dict):
        shallow = copy.deepcopy(structure)
        shallow.pop('nodes', None)
        collected = [shallow]
        for key in list(structure):
            if 'nodes' in key:
                collected.extend(get_nodes(structure[key]))
        return collected
    if isinstance(structure, list):
        collected = []
        for element in structure:
            collected.extend(get_nodes(element))
        return collected
|
||||
|
||||
def structure_to_list(structure):
    """Flatten the tree into a list of node dicts (children kept in place)."""
    if isinstance(structure, dict):
        flat = [structure]
        if 'nodes' in structure:
            flat.extend(structure_to_list(structure['nodes']))
        return flat
    if isinstance(structure, list):
        flat = []
        for entry in structure:
            flat.extend(structure_to_list(entry))
        return flat
|
||||
|
||||
|
||||
def get_leaf_nodes(structure):
    """Return deep copies of all leaf nodes, with any 'nodes' key stripped.

    A leaf is a dict whose 'nodes' entry is missing or empty.

    Fix: the original indexed ``structure['nodes']`` directly and raised
    ``KeyError`` for nodes without a 'nodes' key; such nodes are leaves too,
    so ``.get()`` is used now.
    """
    if isinstance(structure, dict):
        if not structure.get('nodes'):
            leaf = copy.deepcopy(structure)
            leaf.pop('nodes', None)
            return [leaf]
        leaf_nodes = []
        for key in list(structure.keys()):
            if 'nodes' in key:
                leaf_nodes.extend(get_leaf_nodes(structure[key]))
        return leaf_nodes
    elif isinstance(structure, list):
        leaf_nodes = []
        for item in structure:
            leaf_nodes.extend(get_leaf_nodes(item))
        return leaf_nodes
|
||||
|
||||
def is_leaf_node(data, node_id):
    """True if the node with *node_id* exists and has no (or empty) children."""

    def _locate(subtree):
        # Depth-first search for the node carrying node_id.
        if isinstance(subtree, dict):
            if subtree.get('node_id') == node_id:
                return subtree
            for key in subtree:
                if 'nodes' in key:
                    found = _locate(subtree[key])
                    if found:
                        return found
        elif isinstance(subtree, list):
            for entry in subtree:
                found = _locate(entry)
                if found:
                    return found
        return None

    target = _locate(data)
    if target and not target.get('nodes'):
        return True
    return False
|
||||
|
||||
def get_last_node(structure):
    """Return the final node of a flat node list."""
    last_position = len(structure) - 1
    return structure[last_position]
|
||||
|
||||
|
||||
def extract_text_from_pdf(pdf_path):
    """Concatenate the extracted text of every page into a single string."""
    reader = PyPDF2.PdfReader(pdf_path)
    return "".join(page.extract_text() for page in reader.pages)
|
||||
|
||||
def get_pdf_title(pdf_path):
    """Return the Title entry from the PDF's metadata (may be None)."""
    metadata = PyPDF2.PdfReader(pdf_path).metadata
    return metadata.title
|
||||
|
||||
def get_text_of_pages(pdf_path, start_page, end_page, tag=True):
    """Extract text for 1-based pages start_page..end_page inclusive.

    With tag=True each page is wrapped in <start_index_N>/<end_index_N>
    markers so downstream parsing can recover page boundaries.
    """
    reader = PyPDF2.PdfReader(pdf_path)
    pieces = []
    for page_num in range(start_page - 1, end_page):
        page_text = reader.pages[page_num].extract_text()
        if tag:
            pieces.append(f"<start_index_{page_num+1}>\n{page_text}\n<end_index_{page_num+1}>\n")
        else:
            pieces.append(page_text)
    return "".join(pieces)
|
||||
|
||||
def get_first_start_page_from_text(text):
    """Return the page number of the first <start_index_N> tag, or -1 if absent.

    Fix: the module never imported ``re``, so this raised NameError at call
    time; ``import re`` is now part of the module imports.
    """
    match = re.search(r'<start_index_(\d+)>', text)
    if match:
        return int(match.group(1))
    return -1
|
||||
|
||||
def get_last_start_page_from_text(text):
    """Return the page number of the last <start_index_N> tag, or -1 if absent.

    Fix: the module never imported ``re``, so this raised NameError at call
    time; ``import re`` is now part of the module imports.
    """
    matches = list(re.finditer(r'<start_index_(\d+)>', text))
    if matches:
        return int(matches[-1].group(1))
    return -1
|
||||
|
||||
|
||||
def sanitize_filename(filename, replacement='-'):
    """Replace '/' in *filename*.

    On Linux only '/' and the null byte are invalid in filenames; null cannot
    occur in a Python str, so '/' is the only character to handle.
    """
    return replacement.join(filename.split('/'))
|
||||
|
||||
def get_pdf_name(pdf_path):
    """Derive a filesystem-safe display name for the PDF.

    Paths yield their basename; in-memory BytesIO streams fall back to the
    metadata title (or 'Untitled' when missing/empty).
    """
    if isinstance(pdf_path, str):
        pdf_name = os.path.basename(pdf_path)
    elif isinstance(pdf_path, BytesIO):
        metadata = PyPDF2.PdfReader(pdf_path).metadata
        pdf_name = metadata.title or 'Untitled'
    return sanitize_filename(pdf_name)
|
||||
|
||||
|
||||
class JsonLogger:
    """Collects structured log entries and persists them under ./logs/.

    Every call to log() rewrites the whole file, so the JSON on disk is
    always complete and valid.
    """

    def __init__(self, file_path):
        # The log file is named after the PDF plus a timestamp.
        pdf_name = get_pdf_name(file_path)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.filename = f"{pdf_name}_{timestamp}.json"
        os.makedirs("./logs", exist_ok=True)
        self.log_data = []  # all messages logged so far, in order

    def log(self, level, message, **kwargs):
        # Dict messages are stored as-is; anything else is wrapped.
        # NOTE(review): *level* and **kwargs are accepted but not recorded —
        # entries carry no severity; preserved for output compatibility.
        entry = message if isinstance(message, dict) else {'message': message}
        self.log_data.append(entry)
        # Persist the full list so the file always holds valid JSON.
        with open(self._filepath(), "w") as f:
            json.dump(self.log_data, f, indent=2)

    def info(self, message, **kwargs):
        self.log("INFO", message, **kwargs)

    def error(self, message, **kwargs):
        self.log("ERROR", message, **kwargs)

    def debug(self, message, **kwargs):
        self.log("DEBUG", message, **kwargs)

    def exception(self, message, **kwargs):
        kwargs["exception"] = True
        self.log("ERROR", message, **kwargs)

    def _filepath(self):
        return os.path.join("logs", self.filename)
|
||||
|
||||
|
||||
|
||||
|
||||
def list_to_tree(data):
    """Build a nested tree from a flat list of sections with dotted 'structure' codes.

    A section coded '2.1.3' becomes a child of '2.1' when that parent has
    already been seen; otherwise it is treated as a root. Empty 'nodes'
    lists are removed from the result.
    """

    def _parent_code(code):
        """Dotted code of the parent, or None for top-level/empty codes."""
        if not code:
            return None
        pieces = str(code).split('.')
        if len(pieces) > 1:
            return '.'.join(pieces[:-1])
        return None

    by_code = {}
    roots = []

    for item in data:
        code = item.get('structure')
        node = {
            'title': item.get('title'),
            'start_index': item.get('start_index'),
            'end_index': item.get('end_index'),
            'nodes': [],
        }
        by_code[code] = node

        parent = _parent_code(code)
        # Attach under the parent when we have seen it; otherwise it's a root.
        if parent and parent in by_code:
            by_code[parent]['nodes'].append(node)
        else:
            roots.append(node)

    def _prune(node):
        """Drop empty 'nodes' arrays; recurse into non-empty ones."""
        if node['nodes']:
            for child in node['nodes']:
                _prune(child)
        else:
            del node['nodes']
        return node

    return [_prune(node) for node in roots]
|
||||
|
||||
def add_preface_if_needed(data):
    """Prepend a synthetic 'Preface' section when the first one starts after page 1.

    Mutates *data* in place and also returns it. Non-list or empty input is
    returned unchanged.

    Fix: the original indexed ``data[0]['physical_index']`` directly and
    raised KeyError when the first item had no 'physical_index' key; such
    lists are now returned unchanged.
    """
    if not isinstance(data, list) or not data:
        return data

    first_index = data[0].get('physical_index')
    if first_index is not None and first_index > 1:
        preface_node = {
            "structure": "0",
            "title": "Preface",
            "physical_index": 1,
        }
        data.insert(0, preface_node)
    return data
|
||||
|
||||
|
||||
|
||||
def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
    """Return a list of (page_text, token_count) tuples, one per page.

    Raises ValueError for an unknown *pdf_parser*.

    Fix: the PyMuPDF branch previously reused the PyPDF2 iteration API
    (``reader.pages`` sequence and ``page.extract_text()``), which pymupdf
    Documents do not provide, so it crashed at runtime. Pages are now read
    by iterating the document and calling ``page.get_text()``.
    """
    if pdf_parser == "PyPDF2":
        reader = PyPDF2.PdfReader(pdf_path)
        page_texts = [page.extract_text() for page in reader.pages]
    elif pdf_parser == "PyMuPDF":
        with pymupdf.open(pdf_path) as doc:
            page_texts = [page.get_text() for page in doc]
    else:
        raise ValueError(f"Unsupported PDF parser: {pdf_parser}")

    enc = tiktoken.encoding_for_model(model)
    return [(text, len(enc.encode(text))) for text in page_texts]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_text_of_pdf_pages(pdf_pages, start_page, end_page):
    """Concatenate cached page text for 1-based pages start_page..end_page inclusive.

    *pdf_pages* is the (text, token_count) list produced by get_page_tokens.
    """
    return "".join(
        pdf_pages[page_num][0] for page_num in range(start_page - 1, end_page)
    )
|
||||
|
||||
def get_number_of_pages(pdf_path):
    """Return the page count of the PDF."""
    return len(PyPDF2.PdfReader(pdf_path).pages)
|
||||
|
||||
|
||||
|
||||
def post_processing(structure, end_physical_index):
    """Turn physical page numbers into start/end index ranges, then nest the list.

    Each section ends where the next begins (minus one when the next section
    visibly starts on its own page, i.e. appear_start == 'yes'); the final
    section ends at *end_physical_index*. When no tree can be built from the
    'structure' codes, the cleaned flat list is returned instead.
    """
    last = len(structure) - 1
    for position, section in enumerate(structure):
        section['start_index'] = section.get('physical_index')
        if position == last:
            section['end_index'] = end_physical_index
        else:
            following = structure[position + 1]
            if following.get('appear_start') == 'yes':
                section['end_index'] = following['physical_index'] - 1
            else:
                section['end_index'] = following['physical_index']

    tree = list_to_tree(structure)
    if tree:
        return tree
    # No tree could be formed: strip the helper keys and return the flat list.
    for section in structure:
        section.pop('appear_start', None)
        section.pop('physical_index', None)
    return structure
|
||||
|
||||
def clean_structure_post(data):
    """Recursively delete 'page_number', 'start_index' and 'end_index' in place."""
    if isinstance(data, dict):
        for obsolete_key in ('page_number', 'start_index', 'end_index'):
            data.pop(obsolete_key, None)
        if 'nodes' in data:
            clean_structure_post(data['nodes'])
    elif isinstance(data, list):
        for entry in data:
            clean_structure_post(entry)
    return data
|
||||
|
||||
|
||||
def remove_structure_text(data):
    """Recursively drop the bulky 'text' field from every node in place."""
    if isinstance(data, dict):
        data.pop('text', None)
        if 'nodes' in data:
            remove_structure_text(data['nodes'])
    elif isinstance(data, list):
        for entry in data:
            remove_structure_text(entry)
    return data
|
||||
|
||||
|
||||
def check_token_limit(structure, limit=110000):
    """Print diagnostics for every node whose 'text' exceeds *limit* tokens.

    Fix: the original bound its working list to the name ``list``, shadowing
    the builtin; renamed to ``nodes``.
    """
    nodes = structure_to_list(structure)
    for node in nodes:
        num_tokens = count_tokens(node['text'], model='gpt-4o')
        if num_tokens > limit:
            print(f"Node ID: {node['node_id']} has {num_tokens} tokens")
            print("Start Index:", node['start_index'])
            print("End Index:", node['end_index'])
            print("Title:", node['title'])
            print("\n")
|
||||
|
||||
|
||||
def convert_physical_index_to_int(data):
    """Normalize 'physical_index' values from tag strings to ints.

    Accepts either a list of dicts (converted in place and returned) or a
    single string; a string that matches neither '<physical_index_N>' nor
    'physical_index_N' yields None.
    """
    tagged_prefix = '<physical_index_'
    bare_prefix = 'physical_index_'
    if isinstance(data, list):
        for entry in data:
            if not (isinstance(entry, dict) and 'physical_index' in entry):
                continue
            value = entry['physical_index']
            if isinstance(value, str):
                if value.startswith(tagged_prefix):
                    entry['physical_index'] = int(value.split('_')[-1].rstrip('>').strip())
                elif value.startswith(bare_prefix):
                    entry['physical_index'] = int(value.split('_')[-1].strip())
    elif isinstance(data, str):
        if data.startswith(tagged_prefix):
            data = int(data.split('_')[-1].rstrip('>').strip())
        elif data.startswith(bare_prefix):
            data = int(data.split('_')[-1].strip())
        # A string that could not be parsed yields None.
        if isinstance(data, int):
            return data
        return None
    return data
|
||||
|
||||
|
||||
def convert_page_to_int(data):
    """Coerce string 'page' values to int in place; non-numeric strings stay as-is."""
    for entry in data:
        page = entry.get('page')
        if isinstance(page, str):
            try:
                entry['page'] = int(page)
            except ValueError:
                pass  # keep the original string when it is not a number
    return data
|
||||
|
||||
def write_node_id(data, node_id=0):
    """Assign zero-padded 'node_id' strings depth-first; return the next free id.

    NOTE(review): this is a duplicate of the write_node_id defined earlier in
    this module; the later definition wins at import time. Consider removing
    one of them.
    """
    if isinstance(data, dict):
        data['node_id'] = str(node_id).zfill(4)
        node_id += 1
        child_keys = [key for key in data if 'nodes' in key]
        for key in child_keys:
            node_id = write_node_id(data[key], node_id)
    elif isinstance(data, list):
        for element in data:
            node_id = write_node_id(element, node_id)
    return node_id
|
||||
|
||||
|
||||
def add_node_text(node, pdf_pages):
    """Recursively attach the raw page text covered by each node's index range."""
    if isinstance(node, dict):
        first_page = node.get('start_index')
        last_page = node.get('end_index')
        node['text'] = get_text_of_pdf_pages(pdf_pages, first_page, last_page)
        if 'nodes' in node:
            add_node_text(node['nodes'], pdf_pages)
    elif isinstance(node, list):
        for child in node:
            add_node_text(child, pdf_pages)
    return
|
||||
|
||||
async def generate_node_summary(node, model=None):
    """Ask the LLM for a short description of this node's text."""
    prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.

    Partial Document Text: {node['text']}

    Directly return the description, do not include any other text.
    """
    return await ChatGPT_API_async(model, prompt)
|
||||
|
||||
|
||||
async def generate_summaries_for_structure(structure, model=None):
    """Generate summaries for every node concurrently and attach them in place."""
    nodes = structure_to_list(structure)
    summaries = await asyncio.gather(
        *(generate_node_summary(node, model=model) for node in nodes)
    )
    for node, summary in zip(nodes, summaries):
        node['summary'] = summary
    return structure
|
||||
|
||||
|
||||
def generate_doc_description(structure, model=None):
    """Ask the LLM for a one-sentence description of the whole document.

    Fix: corrected the "Your are" typo in the prompt text.
    """
    prompt = f"""You are an expert in generating descriptions for a document.
    You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.

    Document Structure: {structure}

    Directly return the description, do not include any other text.
    """
    response = ChatGPT_API(model, prompt)
    return response
|
||||
|
||||
|
||||
class ConfigLoader:
    """Merges user-supplied options over the defaults read from config.yaml."""

    def __init__(self, default_path: str = None):
        # Default to the config.yaml that ships next to this module.
        if default_path is None:
            default_path = Path(__file__).parent / "config.yaml"
        self._default_dict = self._load_yaml(default_path)

    @staticmethod
    def _load_yaml(path):
        """Parse a YAML file, treating an empty document as {}."""
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _validate_keys(self, user_dict):
        """Reject options that have no counterpart among the defaults."""
        unknown_keys = set(user_dict) - set(self._default_dict)
        if unknown_keys:
            raise ValueError(f"Unknown config keys: {unknown_keys}")

    def load(self, user_opt=None) -> config:
        """
        Load the configuration, merging user options with default values.
        """
        if user_opt is None:
            user_dict = {}
        elif isinstance(user_opt, config):
            user_dict = vars(user_opt)
        elif isinstance(user_opt, dict):
            user_dict = user_opt
        else:
            raise TypeError("user_opt must be dict, config(SimpleNamespace) or None")

        self._validate_keys(user_dict)
        return config(**{**self._default_dict, **user_dict})
|
||||
Loading…
Add table
Add a link
Reference in a new issue