# PageIndex/pageindex/utils.py

import tiktoken
import openai
import logging
import os
from datetime import datetime
import time
import json
import re
import PyPDF2
import copy
import asyncio
import pymupdf
from io import BytesIO
from dotenv import load_dotenv
load_dotenv()
import yaml
from pathlib import Path
from types import SimpleNamespace as config
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
def count_tokens(text, model):
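    """Return the number of tokens in `text` under the tokenizer for `model`."""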
enc = tiktoken.encoding_for_model(model)
tokens = enc.encode(text)
return len(tokens)
def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
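    """Call the OpenAI chat completions API and return (content, finish_status).

    finish_status is "max_output_reached" when the reply was cut off by the
    output-token limit and "finished" otherwise. Retries up to max_retries
    times on any exception; returns ("Error", "error") if all attempts fail.
    """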
max_retries = 10
client = openai.OpenAI(api_key=api_key)
for i in range(max_retries):
try:
if chat_history:
messages = chat_history
messages.append({"role": "user", "content": prompt})
else:
messages = [{"role": "user", "content": prompt}]
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
)
if response.choices[0].finish_reason == "length":
return response.choices[0].message.content, "max_output_reached"
else:
return response.choices[0].message.content, "finished"
except Exception as e:
print('************* Retrying *************')
logging.error(f"Error: {e}")
if i < max_retries - 1:
                time.sleep(1)  # Wait for 1 second before retrying
else:
logging.error('Max retries reached for prompt: ' + prompt)
return "Error"
def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
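    """Call the OpenAI chat completions API and return the response text.

    Retries up to max_retries times on any exception; returns "Error" if
    all attempts fail.
    """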
max_retries = 10
client = openai.OpenAI(api_key=api_key)
for i in range(max_retries):
try:
if chat_history:
messages = chat_history
messages.append({"role": "user", "content": prompt})
else:
messages = [{"role": "user", "content": prompt}]
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
)
return response.choices[0].message.content
except Exception as e:
print('************* Retrying *************')
logging.error(f"Error: {e}")
if i < max_retries - 1:
                time.sleep(1)  # Wait for 1 second before retrying
else:
logging.error('Max retries reached for prompt: ' + prompt)
return "Error"
async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
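    """Async variant of ChatGPT_API, using openai.AsyncOpenAI.

    Retries up to max_retries times on any exception; returns "Error" if
    all attempts fail.
    """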
max_retries = 10
messages = [{"role": "user", "content": prompt}]
for i in range(max_retries):
try:
async with openai.AsyncOpenAI(api_key=api_key) as client:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
)
return response.choices[0].message.content
except Exception as e:
print('************* Retrying *************')
logging.error(f"Error: {e}")
if i < max_retries - 1:
await asyncio.sleep(1) # Wait for 1s before retrying
else:
logging.error('Max retries reached for prompt: ' + prompt)
return "Error"
def get_json_content(response):
start_idx = response.find("```json")
if start_idx != -1:
start_idx += 7
response = response[start_idx:]
end_idx = response.rfind("```")
if end_idx != -1:
response = response[:end_idx]
json_content = response.strip()
return json_content
def extract_json(content):
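    """Parse a JSON object out of an LLM response.

    Prefers content fenced by ```json ... ```; otherwise treats the whole
    string as JSON. Applies light, lossy cleanup (None -> null, newline and
    whitespace normalization, trailing commas) and returns {} on failure.
    """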
try:
# First, try to extract JSON enclosed within ```json and ```
start_idx = content.find("```json")
if start_idx != -1:
start_idx += 7 # Adjust index to start after the delimiter
end_idx = content.rfind("```")
json_content = content[start_idx:end_idx].strip()
else:
# If no delimiters, assume entire content could be JSON
json_content = content.strip()
# Clean up common issues that might cause parsing errors
json_content = json_content.replace('None', 'null') # Replace Python None with JSON null
json_content = json_content.replace('\n', ' ').replace('\r', ' ') # Remove newlines
json_content = ' '.join(json_content.split()) # Normalize whitespace
# Attempt to parse and return the JSON object
return json.loads(json_content)
except json.JSONDecodeError as e:
logging.error(f"Failed to extract JSON: {e}")
# Try to clean up the content further if initial parsing fails
try:
# Remove any trailing commas before closing brackets/braces
json_content = json_content.replace(',]', ']').replace(',}', '}')
return json.loads(json_content)
        except Exception:
logging.error("Failed to parse JSON even after cleanup")
return {}
except Exception as e:
logging.error(f"Unexpected error while extracting JSON: {e}")
return {}
def write_node_id(data, node_id=0):
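    """Assign zero-padded string node_ids ("0000", "0001", ...) in pre-order.

    Recurses into every dict key whose name contains 'nodes' and returns
    the next unused id.
    """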
if isinstance(data, dict):
data['node_id'] = str(node_id).zfill(4)
node_id += 1
for key in list(data.keys()):
if 'nodes' in key:
node_id = write_node_id(data[key], node_id)
elif isinstance(data, list):
for index in range(len(data)):
node_id = write_node_id(data[index], node_id)
return node_id
def get_nodes(structure):
if isinstance(structure, dict):
structure_node = copy.deepcopy(structure)
structure_node.pop('nodes', None)
nodes = [structure_node]
for key in list(structure.keys()):
if 'nodes' in key:
nodes.extend(get_nodes(structure[key]))
return nodes
elif isinstance(structure, list):
nodes = []
for item in structure:
nodes.extend(get_nodes(item))
return nodes
def structure_to_list(structure):
if isinstance(structure, dict):
nodes = []
nodes.append(structure)
if 'nodes' in structure:
nodes.extend(structure_to_list(structure['nodes']))
return nodes
elif isinstance(structure, list):
nodes = []
for item in structure:
nodes.extend(structure_to_list(item))
return nodes
def get_leaf_nodes(structure):
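    """Collect copies (without their 'nodes' key) of all leaf nodes, i.e.
    nodes whose 'nodes' list is empty."""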
if isinstance(structure, dict):
if not structure['nodes']:
structure_node = copy.deepcopy(structure)
structure_node.pop('nodes', None)
return [structure_node]
else:
leaf_nodes = []
for key in list(structure.keys()):
if 'nodes' in key:
leaf_nodes.extend(get_leaf_nodes(structure[key]))
return leaf_nodes
elif isinstance(structure, list):
leaf_nodes = []
for item in structure:
leaf_nodes.extend(get_leaf_nodes(item))
return leaf_nodes
def is_leaf_node(data, node_id):
# Helper function to find the node by its node_id
def find_node(data, node_id):
if isinstance(data, dict):
if data.get('node_id') == node_id:
return data
for key in data.keys():
if 'nodes' in key:
result = find_node(data[key], node_id)
if result:
return result
elif isinstance(data, list):
for item in data:
result = find_node(item, node_id)
if result:
return result
return None
# Find the node with the given node_id
node = find_node(data, node_id)
# Check if the node is a leaf node
if node and not node.get('nodes'):
return True
return False
def get_last_node(structure):
return structure[-1]
def extract_text_from_pdf(pdf_path):
pdf_reader = PyPDF2.PdfReader(pdf_path)
    # Return the full text as a single string, not a per-page list
    text = ""
    for page_num in range(len(pdf_reader.pages)):
        page = pdf_reader.pages[page_num]
        text += page.extract_text()
return text
def get_pdf_title(pdf_path):
pdf_reader = PyPDF2.PdfReader(pdf_path)
meta = pdf_reader.metadata
title = meta.title if meta and meta.title else 'Untitled'
return title
def get_text_of_pages(pdf_path, start_page, end_page, tag=True):
pdf_reader = PyPDF2.PdfReader(pdf_path)
text = ""
for page_num in range(start_page-1, end_page):
page = pdf_reader.pages[page_num]
page_text = page.extract_text()
if tag:
text += f"<start_index_{page_num+1}>\n{page_text}\n<end_index_{page_num+1}>\n"
else:
text += page_text
return text
def get_first_start_page_from_text(text):
start_page = -1
start_page_match = re.search(r'<start_index_(\d+)>', text)
if start_page_match:
start_page = int(start_page_match.group(1))
return start_page
def get_last_start_page_from_text(text):
start_page = -1
# Find all matches of start_index tags
start_page_matches = re.finditer(r'<start_index_(\d+)>', text)
# Convert iterator to list and get the last match if any exist
matches_list = list(start_page_matches)
if matches_list:
start_page = int(matches_list[-1].group(1))
return start_page
def sanitize_filename(filename, replacement='-'):
# In Linux, only '/' and '\0' (null) are invalid in filenames.
# Null can't be represented in strings, so we only handle '/'.
return filename.replace('/', replacement)
def get_pdf_name(pdf_path):
# Extract PDF name
    if isinstance(pdf_path, str):
        pdf_name = os.path.basename(pdf_path)
    elif isinstance(pdf_path, BytesIO):
        pdf_reader = PyPDF2.PdfReader(pdf_path)
        meta = pdf_reader.metadata
        pdf_name = meta.title if meta and meta.title else 'Untitled'
    else:
        raise TypeError("pdf_path must be a file path string or a BytesIO stream")
    pdf_name = sanitize_filename(pdf_name)
return pdf_name
class JsonLogger:
def __init__(self, file_path):
# Extract PDF name for logger name
pdf_name = get_pdf_name(file_path)
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
self.filename = f"{pdf_name}_{current_time}.json"
os.makedirs("./logs", exist_ok=True)
# Initialize empty list to store all messages
self.log_data = []
    def log(self, level, message, **kwargs):
        # Note: 'level' is accepted for logger-style compatibility but is not stored.
        # Add the new message to the log data
        if isinstance(message, dict):
            self.log_data.append(message)
        else:
            self.log_data.append({'message': message})
        # Write the entire log data to file
        with open(self._filepath(), "w") as f:
            json.dump(self.log_data, f, indent=2)
def info(self, message, **kwargs):
self.log("INFO", message, **kwargs)
def error(self, message, **kwargs):
self.log("ERROR", message, **kwargs)
def debug(self, message, **kwargs):
self.log("DEBUG", message, **kwargs)
def exception(self, message, **kwargs):
kwargs["exception"] = True
self.log("ERROR", message, **kwargs)
def _filepath(self):
return os.path.join("logs", self.filename)
def list_to_tree(data):
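    """Build a tree from a flat list of items with dotted 'structure' codes.

    An item whose structure is "1.2" becomes a child of the item with
    structure "1"; items whose parent is missing are treated as roots.
    Empty 'nodes' lists are removed from the result.
    """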
def get_parent_structure(structure):
"""Helper function to get the parent structure code"""
if not structure:
return None
parts = str(structure).split('.')
return '.'.join(parts[:-1]) if len(parts) > 1 else None
# First pass: Create nodes and track parent-child relationships
nodes = {}
root_nodes = []
for item in data:
structure = item.get('structure')
node = {
'title': item.get('title'),
'start_index': item.get('start_index'),
'end_index': item.get('end_index'),
'nodes': []
}
nodes[structure] = node
# Find parent
parent_structure = get_parent_structure(structure)
if parent_structure:
# Add as child to parent if parent exists
if parent_structure in nodes:
nodes[parent_structure]['nodes'].append(node)
else:
root_nodes.append(node)
else:
# No parent, this is a root node
root_nodes.append(node)
# Helper function to clean empty children arrays
def clean_node(node):
if not node['nodes']:
del node['nodes']
else:
for child in node['nodes']:
clean_node(child)
return node
# Clean and return the tree
return [clean_node(node) for node in root_nodes]
def add_preface_if_needed(data):
if not isinstance(data, list) or not data:
return data
    if data[0].get('physical_index') is not None and data[0]['physical_index'] > 1:
preface_node = {
"structure": "0",
"title": "Preface",
"physical_index": 1,
}
data.insert(0, preface_node)
return data
def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
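    """Return a list of (page_text, token_count) pairs, one per PDF page.

    Supports PyPDF2 (default) and PyMuPDF as the underlying parser.
    """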
enc = tiktoken.encoding_for_model(model)
if pdf_parser == "PyPDF2":
pdf_reader = PyPDF2.PdfReader(pdf_path)
page_list = []
for page_num in range(len(pdf_reader.pages)):
page = pdf_reader.pages[page_num]
page_text = page.extract_text()
token_length = len(enc.encode(page_text))
page_list.append((page_text, token_length))
return page_list
elif pdf_parser == "PyMuPDF":
        if isinstance(pdf_path, BytesIO):
            pdf_stream = pdf_path
            doc = pymupdf.open(stream=pdf_stream, filetype="pdf")
        elif isinstance(pdf_path, str) and os.path.isfile(pdf_path) and pdf_path.lower().endswith(".pdf"):
            doc = pymupdf.open(pdf_path)
        else:
            raise ValueError("pdf_path must be a BytesIO stream or a path to an existing .pdf file")
page_list = []
for page in doc:
page_text = page.get_text()
token_length = len(enc.encode(page_text))
page_list.append((page_text, token_length))
return page_list
else:
raise ValueError(f"Unsupported PDF parser: {pdf_parser}")
def get_text_of_pdf_pages(pdf_pages, start_page, end_page):
text = ""
for page_num in range(start_page-1, end_page):
text += pdf_pages[page_num][0]
return text
def get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page):
text = ""
for page_num in range(start_page-1, end_page):
text += f"<physical_index_{page_num+1}>\n{pdf_pages[page_num][0]}\n<physical_index_{page_num+1}>\n"
return text
def get_number_of_pages(pdf_path):
pdf_reader = PyPDF2.PdfReader(pdf_path)
num = len(pdf_reader.pages)
return num
def post_processing(structure, end_physical_index):
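    """Convert physical_index values to start/end indices and build a tree.

    Each item's end_index is the next item's start page, ending one page
    earlier when the next section begins at the top of its page
    (appear_start == 'yes'); the last item ends at end_physical_index.
    Falls back to the cleaned flat list if tree construction yields nothing.
    """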
# First convert page_number to start_index in flat list
for i, item in enumerate(structure):
item['start_index'] = item.get('physical_index')
if i < len(structure) - 1:
if structure[i + 1].get('appear_start') == 'yes':
item['end_index'] = structure[i + 1]['physical_index']-1
else:
item['end_index'] = structure[i + 1]['physical_index']
else:
item['end_index'] = end_physical_index
tree = list_to_tree(structure)
if len(tree)!=0:
return tree
else:
### remove appear_start
for node in structure:
node.pop('appear_start', None)
node.pop('physical_index', None)
return structure
def clean_structure_post(data):
if isinstance(data, dict):
data.pop('page_number', None)
data.pop('start_index', None)
data.pop('end_index', None)
if 'nodes' in data:
clean_structure_post(data['nodes'])
elif isinstance(data, list):
for section in data:
clean_structure_post(section)
return data
def remove_structure_text(data):
if isinstance(data, dict):
data.pop('text', None)
if 'nodes' in data:
remove_structure_text(data['nodes'])
elif isinstance(data, list):
for item in data:
remove_structure_text(item)
return data
def check_token_limit(structure, limit=110000):
    node_list = structure_to_list(structure)  # avoid shadowing the built-in 'list'
    for node in node_list:
num_tokens = count_tokens(node['text'], model='gpt-4o')
if num_tokens > limit:
print(f"Node ID: {node['node_id']} has {num_tokens} tokens")
print("Start Index:", node['start_index'])
print("End Index:", node['end_index'])
print("Title:", node['title'])
print("\n")
def convert_physical_index_to_int(data):
if isinstance(data, list):
for i in range(len(data)):
# Check if item is a dictionary and has 'physical_index' key
if isinstance(data[i], dict) and 'physical_index' in data[i]:
if isinstance(data[i]['physical_index'], str):
if data[i]['physical_index'].startswith('<physical_index_'):
data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].rstrip('>').strip())
elif data[i]['physical_index'].startswith('physical_index_'):
data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].strip())
elif isinstance(data, str):
if data.startswith('<physical_index_'):
data = int(data.split('_')[-1].rstrip('>').strip())
elif data.startswith('physical_index_'):
data = int(data.split('_')[-1].strip())
        # Return only if the conversion produced an int; otherwise return None
if isinstance(data, int):
return data
else:
return None
return data
def convert_page_to_int(data):
for item in data:
if 'page' in item and isinstance(item['page'], str):
try:
item['page'] = int(item['page'])
except ValueError:
# Keep original value if conversion fails
pass
return data
def add_node_text(node, pdf_pages):
if isinstance(node, dict):
start_page = node.get('start_index')
end_page = node.get('end_index')
node['text'] = get_text_of_pdf_pages(pdf_pages, start_page, end_page)
if 'nodes' in node:
add_node_text(node['nodes'], pdf_pages)
elif isinstance(node, list):
for index in range(len(node)):
add_node_text(node[index], pdf_pages)
return
def add_node_text_with_labels(node, pdf_pages):
if isinstance(node, dict):
start_page = node.get('start_index')
end_page = node.get('end_index')
node['text'] = get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page)
if 'nodes' in node:
add_node_text_with_labels(node['nodes'], pdf_pages)
elif isinstance(node, list):
for index in range(len(node)):
add_node_text_with_labels(node[index], pdf_pages)
return
async def generate_node_summary(node, model=None):
prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.
Partial Document Text: {node['text']}
Directly return the description, do not include any other text.
"""
response = await ChatGPT_API_async(model, prompt)
return response
async def generate_summaries_for_structure(structure, model=None):
nodes = structure_to_list(structure)
tasks = [generate_node_summary(node, model=model) for node in nodes]
summaries = await asyncio.gather(*tasks)
for node, summary in zip(nodes, summaries):
node['summary'] = summary
return structure
def generate_doc_description(structure, model=None):
prompt = f"""Your are an expert in generating descriptions for a document.
You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.
Document Structure: {structure}
Directly return the description, do not include any other text.
"""
response = ChatGPT_API(model, prompt)
return response
class ConfigLoader:
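    """Load defaults from config.yaml and merge user overrides on top.

    load() accepts a dict, a SimpleNamespace, or None and returns a
    SimpleNamespace; unknown keys in the user options raise ValueError.
    """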
def __init__(self, default_path: str = None):
if default_path is None:
default_path = Path(__file__).parent / "config.yaml"
self._default_dict = self._load_yaml(default_path)
@staticmethod
def _load_yaml(path):
with open(path, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {}
def _validate_keys(self, user_dict):
unknown_keys = set(user_dict) - set(self._default_dict)
if unknown_keys:
raise ValueError(f"Unknown config keys: {unknown_keys}")
def load(self, user_opt=None) -> config:
"""
Load the configuration, merging user options with default values.
"""
if user_opt is None:
user_dict = {}
elif isinstance(user_opt, config):
user_dict = vars(user_opt)
elif isinstance(user_opt, dict):
user_dict = user_opt
else:
raise TypeError("user_opt must be dict, config(SimpleNamespace) or None")
self._validate_keys(user_dict)
merged = {**self._default_dict, **user_dict}
return config(**merged)
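
# Illustrative usage sketch, not part of the library API. "example.pdf" is a
# placeholder path and the model name is only an assumption; adjust as needed.
if __name__ == "__main__":
    opt = ConfigLoader().load({})  # merge defaults from config.yaml with no overrides
    pages = get_page_tokens("example.pdf", model="gpt-4o-2024-11-20", pdf_parser="PyPDF2")
    print(f"{len(pages)} pages, {sum(tokens for _, tokens in pages)} tokens in total")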