mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-25 16:16:22 +02:00
640 lines
No EOL
21 KiB
Python
640 lines
No EOL
21 KiB
Python
import asyncio
import copy
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace as config

import openai
import PyPDF2
import pymupdf
import tiktoken
import yaml
from dotenv import load_dotenv

load_dotenv()
|
|
|
|
# Default OpenAI API key, read from the environment (populated by load_dotenv above).
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
|
|
|
|
|
|
def count_tokens(text, model):
    """Return how many tokens *text* occupies under *model*'s tiktoken encoding."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))
|
|
|
|
def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
    """Call the OpenAI chat-completions API and report whether output was truncated.

    Args:
        model: Model name forwarded to the API.
        prompt: User prompt, sent as the final message.
        api_key: API key; defaults to the module-level CHATGPT_API_KEY.
        chat_history: Optional list of prior messages. The caller's list is no
            longer mutated, and the prompt is no longer re-appended on every
            retry (the old code appended to chat_history in place each attempt,
            duplicating the prompt on retries).

    Returns:
        (content, status): status is "max_output_reached" when
        finish_reason == "length", "finished" on normal completion, or
        "error" after all retries fail. The failure content stays the
        historical string "Error", but it is now paired with a status so
        tuple-unpacking callers no longer crash on the failure path.
    """
    max_retries = 10
    client = openai.OpenAI(api_key=api_key)
    # Build the message list once, outside the retry loop, from a copy.
    if chat_history:
        messages = list(chat_history) + [{"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": prompt}]
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            choice = response.choices[0]
            if choice.finish_reason == "length":
                return choice.message.content, "max_output_reached"
            return choice.message.content, "finished"
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
            if attempt < max_retries - 1:
                time.sleep(1)  # Wait 1 second before retrying
            else:
                logging.error('Max retries reached for prompt: ' + prompt)
    return "Error", "error"
|
|
|
|
|
|
|
|
def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
    """Call the OpenAI chat-completions API and return the response text.

    Args:
        model: Model name forwarded to the API.
        prompt: User prompt, sent as the final message.
        api_key: API key; defaults to the module-level CHATGPT_API_KEY.
        chat_history: Optional list of prior messages. The caller's list is no
            longer mutated, and the prompt is no longer re-appended on every
            retry (the old code appended to chat_history in place each attempt,
            duplicating the prompt on retries).

    Returns:
        The response content, or the string "Error" after 10 failed attempts.
    """
    max_retries = 10
    client = openai.OpenAI(api_key=api_key)
    # Build the message list once, outside the retry loop, from a copy.
    if chat_history:
        messages = list(chat_history) + [{"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": prompt}]
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
            if attempt < max_retries - 1:
                time.sleep(1)  # Wait 1 second before retrying
            else:
                logging.error('Max retries reached for prompt: ' + prompt)
    return "Error"
|
|
|
|
|
|
async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
    """Async single-prompt chat completion with up to 10 retries.

    Returns the response text, or the string "Error" once retries are exhausted.
    """
    max_retries = 10
    messages = [{"role": "user", "content": prompt}]
    for attempt in range(max_retries):
        try:
            async with openai.AsyncOpenAI(api_key=api_key) as client:
                completion = await client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=0,
                )
                return completion.choices[0].message.content
        except Exception as exc:
            print('************* Retrying *************')
            logging.error(f"Error: {exc}")
            if attempt == max_retries - 1:
                logging.error('Max retries reached for prompt: ' + prompt)
                return "Error"
            await asyncio.sleep(1)  # Wait for 1s before retrying
|
|
|
|
|
|
def get_json_content(response):
    """Strip an optional ```json ... ``` fence from *response*, returning the inner text stripped."""
    fence_start = response.find("```json")
    if fence_start != -1:
        # Drop everything through the opening fence (7 == len("```json")).
        response = response[fence_start + 7:]
    fence_end = response.rfind("```")
    if fence_end != -1:
        response = response[:fence_end]
    return response.strip()
|
|
|
|
|
|
def extract_json(content):
    """Best-effort parse of a JSON object out of an LLM response.

    Handles an optional ```json ... ``` fence, Python-style None literals,
    embedded newlines, and trailing commas. Returns {} when parsing fails.
    """
    json_content = ""
    try:
        # First, try to extract JSON enclosed within ```json and ```
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # Adjust index to start after the delimiter
            end_idx = content.rfind("```")
            json_content = content[start_idx:end_idx].strip()
        else:
            # If no delimiters, assume entire content could be JSON
            json_content = content.strip()

        # Replace only whole-word None with null. The previous blanket
        # str.replace corrupted string values merely containing "None"
        # (e.g. "NoneSuch" -> "nullSuch").
        json_content = re.sub(r'\bNone\b', 'null', json_content)
        json_content = json_content.replace('\n', ' ').replace('\r', ' ')  # Remove newlines
        json_content = ' '.join(json_content.split())  # Normalize whitespace

        return json.loads(json_content)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to extract JSON: {e}")
        try:
            # Remove any trailing commas before closing brackets/braces and retry.
            json_content = json_content.replace(',]', ']').replace(',}', '}')
            return json.loads(json_content)
        except json.JSONDecodeError:
            # Narrowed from a bare except: only a parse failure is expected here.
            logging.error("Failed to parse JSON even after cleanup")
            return {}
    except Exception as e:
        logging.error(f"Unexpected error while extracting JSON: {e}")
        return {}
|
|
|
|
def write_node_id(data, node_id=0):
    """Assign zero-padded sequential 'node_id' strings to every dict in the tree.

    Traversal is depth-first, recursing through every key whose name contains
    'nodes'. Returns the next unused id so sibling subtrees continue the
    numbering.
    """
    if isinstance(data, dict):
        data['node_id'] = f"{node_id:04d}"
        node_id += 1
        for key in [k for k in data if 'nodes' in k]:
            node_id = write_node_id(data[key], node_id)
    elif isinstance(data, list):
        for child in data:
            node_id = write_node_id(child, node_id)
    return node_id
|
|
|
|
def get_nodes(structure):
    """Flatten the tree into a pre-order list of shallow node dicts.

    Each returned dict is a deep copy of a tree node with its 'nodes' children
    removed; children follow their parent in the output.
    """
    if isinstance(structure, dict):
        snapshot = copy.deepcopy(structure)
        snapshot.pop('nodes', None)
        collected = [snapshot]
        for key, value in structure.items():
            if 'nodes' in key:
                collected.extend(get_nodes(value))
        return collected
    if isinstance(structure, list):
        collected = []
        for item in structure:
            collected.extend(get_nodes(item))
        return collected
|
|
|
|
def structure_to_list(structure):
    """Return every node dict in the tree as a flat pre-order list.

    Unlike get_nodes, the returned dicts are the ORIGINAL objects (no copies)
    and keep their 'nodes' children, so mutating them mutates the tree.
    """
    if isinstance(structure, dict):
        collected = [structure]
        if 'nodes' in structure:
            collected += structure_to_list(structure['nodes'])
        return collected
    if isinstance(structure, list):
        collected = []
        for element in structure:
            collected += structure_to_list(element)
        return collected
|
|
|
|
|
|
def get_leaf_nodes(structure):
    """Collect deep copies of the tree's leaf nodes (nodes with no children).

    A node is a leaf when its 'nodes' value is missing or empty. Returned
    dicts have any 'nodes' key removed.
    """
    if isinstance(structure, dict):
        # .get() instead of [] so nodes whose 'nodes' key was stripped
        # (e.g. by list_to_tree's clean-up of empty children) don't raise KeyError.
        if not structure.get('nodes'):
            structure_node = copy.deepcopy(structure)
            structure_node.pop('nodes', None)
            return [structure_node]
        else:
            leaf_nodes = []
            for key in list(structure.keys()):
                if 'nodes' in key:
                    leaf_nodes.extend(get_leaf_nodes(structure[key]))
            return leaf_nodes
    elif isinstance(structure, list):
        leaf_nodes = []
        for item in structure:
            leaf_nodes.extend(get_leaf_nodes(item))
        return leaf_nodes
|
|
|
|
def is_leaf_node(data, node_id):
    """Return True iff a node with *node_id* exists in *data* and has no children."""

    def _locate(subtree):
        # Depth-first search over dicts/lists for a dict whose node_id matches.
        if isinstance(subtree, dict):
            if subtree.get('node_id') == node_id:
                return subtree
            for key, value in subtree.items():
                if 'nodes' in key:
                    found = _locate(value)
                    if found:
                        return found
        elif isinstance(subtree, list):
            for element in subtree:
                found = _locate(element)
                if found:
                    return found
        return None

    target = _locate(data)
    return bool(target) and not target.get('nodes')
|
|
|
|
def get_last_node(structure):
    """Return the final top-level node of *structure* (IndexError when empty)."""
    final_position = len(structure) - 1
    return structure[final_position]
|
|
|
|
|
|
def extract_text_from_pdf(pdf_path):
    """Extract and concatenate the text of every page of the PDF into one string."""
    reader = PyPDF2.PdfReader(pdf_path)
    return "".join(page.extract_text() for page in reader.pages)
|
|
|
|
def get_pdf_title(pdf_path):
    """Return the PDF's metadata title, or 'Untitled' when metadata/title is absent."""
    metadata = PyPDF2.PdfReader(pdf_path).metadata
    if metadata and metadata.title:
        return metadata.title
    return 'Untitled'
|
|
|
|
def get_text_of_pages(pdf_path, start_page, end_page, tag=True):
    """Concatenate the text of pages start_page..end_page (1-based, inclusive).

    When *tag* is true, each page is wrapped in <start_index_N>/<end_index_N>
    markers carrying its 1-based page number.
    """
    reader = PyPDF2.PdfReader(pdf_path)
    chunks = []
    for page_num in range(start_page - 1, end_page):
        page_text = reader.pages[page_num].extract_text()
        if tag:
            chunks.append(f"<start_index_{page_num+1}>\n{page_text}\n<end_index_{page_num+1}>\n")
        else:
            chunks.append(page_text)
    return "".join(chunks)
|
|
|
|
def get_first_start_page_from_text(text):
    """Return the page number of the first <start_index_N> tag in *text*, or -1 if none."""
    match = re.search(r'<start_index_(\d+)>', text)
    return int(match.group(1)) if match else -1
|
|
|
|
def get_last_start_page_from_text(text):
    """Return the page number of the LAST <start_index_N> tag in *text*, or -1 if none."""
    numbers = re.findall(r'<start_index_(\d+)>', text)
    return int(numbers[-1]) if numbers else -1
|
|
|
|
|
|
def sanitize_filename(filename, replacement='-'):
    """Make *filename* safe as a Linux file name by substituting '/' with *replacement*.

    On Linux only '/' and NUL are invalid in file names; NUL cannot appear in
    a normal string argument here, so only '/' is handled.
    """
    return replacement.join(filename.split('/'))
|
|
|
|
def get_pdf_name(pdf_path):
    """Derive a sanitized display name for a PDF given as a path or a stream.

    Args:
        pdf_path: Either a filesystem path (str) or an in-memory PDF (BytesIO).

    Returns:
        For a str, the file's basename; for a BytesIO, the PDF metadata title
        (or 'Untitled' when missing). The result is passed through
        sanitize_filename so it is safe to use as a file name.

    Raises:
        TypeError: for any other input type (previously this fell through to
            an UnboundLocalError on pdf_name).
    """
    if isinstance(pdf_path, str):
        pdf_name = os.path.basename(pdf_path)
    elif isinstance(pdf_path, BytesIO):
        meta = PyPDF2.PdfReader(pdf_path).metadata
        pdf_name = meta.title if meta and meta.title else 'Untitled'
    else:
        raise TypeError(f"pdf_path must be str or BytesIO, got {type(pdf_path).__name__}")
    return sanitize_filename(pdf_name)
|
|
|
|
|
|
class JsonLogger:
    """Accumulates structured log entries and rewrites them to a JSON file.

    Entries are kept in memory and the whole list is re-serialized to
    ./logs/<pdf_name>_<timestamp>.json on every log call. The level and any
    extra keyword arguments are accepted for API compatibility but are not
    recorded in the file.
    """

    def __init__(self, file_path):
        # Name the log file after the PDF plus a creation timestamp.
        pdf_name = get_pdf_name(file_path)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.filename = f"{pdf_name}_{timestamp}.json"
        os.makedirs("./logs", exist_ok=True)
        self.log_data = []  # every message logged so far, in order

    def log(self, level, message, **kwargs):
        # Dict messages are stored as-is; anything else is wrapped under 'message'.
        entry = message if isinstance(message, dict) else {'message': message}
        self.log_data.append(entry)
        # Rewrite the full log each call so the file on disk is always complete.
        with open(self._filepath(), "w") as f:
            json.dump(self.log_data, f, indent=2)

    def info(self, message, **kwargs):
        self.log("INFO", message, **kwargs)

    def error(self, message, **kwargs):
        self.log("ERROR", message, **kwargs)

    def debug(self, message, **kwargs):
        self.log("DEBUG", message, **kwargs)

    def exception(self, message, **kwargs):
        kwargs["exception"] = True
        self.log("ERROR", message, **kwargs)

    def _filepath(self):
        return os.path.join("logs", self.filename)
|
|
|
|
|
|
|
|
|
|
def list_to_tree(data):
    """Build a forest from a flat list of sections with dotted 'structure' codes.

    A section whose code is '2.1' becomes a child of the section coded '2'.
    Sections without a parent code — or whose parent has not been seen yet —
    become roots. Each output node keeps title/start_index/end_index, and the
    'nodes' key is dropped from childless nodes.
    """

    def _parent_code(code):
        """Dotted prefix of *code* ('2.1.3' -> '2.1'), or None at top level."""
        if not code:
            return None
        pieces = str(code).split('.')
        return '.'.join(pieces[:-1]) if len(pieces) > 1 else None

    nodes_by_code = {}
    roots = []

    for entry in data:
        code = entry.get('structure')
        node = {
            'title': entry.get('title'),
            'start_index': entry.get('start_index'),
            'end_index': entry.get('end_index'),
            'nodes': [],
        }
        nodes_by_code[code] = node
        parent = _parent_code(code)
        # Attach under the parent when it already exists; otherwise it's a root.
        if parent and parent in nodes_by_code:
            nodes_by_code[parent]['nodes'].append(node)
        else:
            roots.append(node)

    def _prune(node):
        """Recursively drop empty 'nodes' lists."""
        if node['nodes']:
            for child in node['nodes']:
                _prune(child)
        else:
            del node['nodes']
        return node

    return [_prune(node) for node in roots]
|
|
|
|
def add_preface_if_needed(data):
    """Insert a 'Preface' section at page 1 when the first section starts later.

    Args:
        data: Flat list of section dicts carrying 'physical_index'.

    Returns:
        The same list; a synthetic Preface node is prepended when the first
        section's physical_index is known and greater than 1. Non-list or
        empty input is returned unchanged.
    """
    if not isinstance(data, list) or not data:
        return data
    # .get() so a first section without 'physical_index' no longer raises KeyError.
    first_index = data[0].get('physical_index')
    if first_index is not None and first_index > 1:
        preface_node = {
            "structure": "0",
            "title": "Preface",
            "physical_index": 1,
        }
        data.insert(0, preface_node)
    return data
|
|
|
|
|
|
|
|
def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
    """Extract per-page text and its token count.

    Args:
        pdf_path: Filesystem path (str) or in-memory PDF (BytesIO).
        model: Model whose tiktoken encoding is used for counting.
        pdf_parser: "PyPDF2" or "PyMuPDF".

    Returns:
        A list of (page_text, token_count) tuples, one per page.

    Raises:
        ValueError: for an unknown parser, or a PyMuPDF input that is neither
            a BytesIO stream nor a path to an existing .pdf file (previously
            that case fell through to an UnboundLocalError on `doc`).
    """
    enc = tiktoken.encoding_for_model(model)
    if pdf_parser == "PyPDF2":
        pdf_reader = PyPDF2.PdfReader(pdf_path)
        page_list = []
        for page in pdf_reader.pages:
            page_text = page.extract_text()
            page_list.append((page_text, len(enc.encode(page_text))))
        return page_list
    elif pdf_parser == "PyMuPDF":
        if isinstance(pdf_path, BytesIO):
            doc = pymupdf.open(stream=pdf_path, filetype="pdf")
        elif isinstance(pdf_path, str) and os.path.isfile(pdf_path) and pdf_path.lower().endswith(".pdf"):
            doc = pymupdf.open(pdf_path)
        else:
            raise ValueError("PyMuPDF parser requires a BytesIO stream or a path to an existing .pdf file")
        page_list = []
        for page in doc:
            page_text = page.get_text()
            page_list.append((page_text, len(enc.encode(page_text))))
        return page_list
    else:
        raise ValueError(f"Unsupported PDF parser: {pdf_parser}")
|
|
|
|
|
|
|
|
def get_text_of_pdf_pages(pdf_pages, start_page, end_page):
    """Join the text of pages start_page..end_page (1-based, inclusive).

    *pdf_pages* is the (text, token_count) list produced by get_page_tokens.
    """
    return "".join(pdf_pages[i][0] for i in range(start_page - 1, end_page))
|
|
|
|
def get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page):
    """Like get_text_of_pdf_pages, but wrap each page in <physical_index_N> markers.

    Note that both the opening and closing markers use the same tag text.
    """
    parts = []
    for i in range(start_page - 1, end_page):
        label = f"<physical_index_{i+1}>"
        parts.append(f"{label}\n{pdf_pages[i][0]}\n{label}\n")
    return "".join(parts)
|
|
|
|
def get_number_of_pages(pdf_path):
    """Return the PDF's page count."""
    return len(PyPDF2.PdfReader(pdf_path).pages)
|
|
|
|
|
|
|
|
def post_processing(structure, end_physical_index):
    """Convert physical page indices into start/end ranges and build the tree.

    Each section's start_index is its own physical_index; its end_index is the
    next section's physical_index, minus one when the next section starts at
    the top of its page (appear_start == 'yes'). The last section ends at
    *end_physical_index*. Returns the tree from list_to_tree, or — when that
    yields no roots — the flat list with the helper keys 'appear_start' and
    'physical_index' stripped.
    """
    last = len(structure) - 1
    for i, section in enumerate(structure):
        section['start_index'] = section.get('physical_index')
        if i == last:
            section['end_index'] = end_physical_index
        else:
            nxt = structure[i + 1]
            offset = 1 if nxt.get('appear_start') == 'yes' else 0
            section['end_index'] = nxt['physical_index'] - offset

    tree = list_to_tree(structure)
    if tree:
        return tree

    # No tree could be built: return the flat list without the helper keys.
    for section in structure:
        section.pop('appear_start', None)
        section.pop('physical_index', None)
    return structure
|
|
|
|
def clean_structure_post(data):
    """Recursively strip 'page_number', 'start_index' and 'end_index' from every node (in place)."""
    if isinstance(data, dict):
        for obsolete in ('page_number', 'start_index', 'end_index'):
            data.pop(obsolete, None)
        if 'nodes' in data:
            clean_structure_post(data['nodes'])
    elif isinstance(data, list):
        for section in data:
            clean_structure_post(section)
    return data
|
|
|
|
|
|
def remove_structure_text(data):
    """Delete the 'text' key from every node dict in the tree (in place) and return *data*."""
    # Iterative traversal; visit order doesn't matter for key deletion.
    pending = [data]
    while pending:
        current = pending.pop()
        if isinstance(current, dict):
            current.pop('text', None)
            if 'nodes' in current:
                pending.append(current['nodes'])
        elif isinstance(current, list):
            pending.extend(current)
    return data
|
|
|
|
|
|
def check_token_limit(structure, limit=110000):
    """Print diagnostics for any node whose 'text' exceeds *limit* tokens.

    Args:
        structure: Tree (or list of trees) whose nodes carry 'text',
            'node_id', 'start_index', 'end_index' and 'title'.
        limit: Token threshold; counts use the gpt-4o encoding.
    """
    # Renamed from `list` to avoid shadowing the builtin.
    nodes = structure_to_list(structure)
    for node in nodes:
        num_tokens = count_tokens(node['text'], model='gpt-4o')
        if num_tokens > limit:
            print(f"Node ID: {node['node_id']} has {num_tokens} tokens")
            print("Start Index:", node['start_index'])
            print("End Index:", node['end_index'])
            print("Title:", node['title'])
            print("\n")
|
|
|
|
|
|
def convert_physical_index_to_int(data):
    """Normalize 'physical_index' values like '<physical_index_12>' or
    'physical_index_12' into the integer 12.

    For a list: rewrites each item's 'physical_index' in place. For a str:
    returns the parsed int. For an int: returns it unchanged. See the NOTE
    below about the return value for list input.
    """
    if isinstance(data, list):
        for i in range(len(data)):
            # Check if item is a dictionary and has 'physical_index' key
            if isinstance(data[i], dict) and 'physical_index' in data[i]:
                if isinstance(data[i]['physical_index'], str):
                    if data[i]['physical_index'].startswith('<physical_index_'):
                        # '<physical_index_12>' -> 12 (strip the trailing '>')
                        data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].rstrip('>').strip())
                    elif data[i]['physical_index'].startswith('physical_index_'):
                        # 'physical_index_12' -> 12
                        data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].strip())
    elif isinstance(data, str):
        if data.startswith('<physical_index_'):
            data = int(data.split('_')[-1].rstrip('>').strip())
        elif data.startswith('physical_index_'):
            data = int(data.split('_')[-1].strip())
    # Check data is int
    # NOTE(review): this check also runs for list input, so a list argument
    # returns None here after its items were mutated in place, and the final
    # `return data` below is unreachable. Callers therefore must rely on the
    # in-place mutation for lists — confirm against call sites before changing.
    if isinstance(data, int):
        return data
    else:
        return None
    return data
|
|
|
|
|
|
def convert_page_to_int(data):
    """Coerce each item's 'page' value from str to int, in place.

    Values that don't parse as integers are left untouched. Returns *data*.
    """
    for entry in data:
        page = entry.get('page')
        if isinstance(page, str):
            try:
                entry['page'] = int(page)
            except ValueError:
                pass  # keep the original string when it isn't numeric
    return data
|
|
|
|
|
|
def add_node_text(node, pdf_pages):
    """Attach each node's raw page text (start_index..end_index) as node['text'], recursively."""
    if isinstance(node, list):
        for child in node:
            add_node_text(child, pdf_pages)
    elif isinstance(node, dict):
        node['text'] = get_text_of_pdf_pages(pdf_pages, node.get('start_index'), node.get('end_index'))
        if 'nodes' in node:
            add_node_text(node['nodes'], pdf_pages)
    return
|
|
|
|
|
|
def add_node_text_with_labels(node, pdf_pages):
    """Attach each node's page text, wrapped in <physical_index_N> markers, as node['text'], recursively."""
    if isinstance(node, list):
        for child in node:
            add_node_text_with_labels(child, pdf_pages)
    elif isinstance(node, dict):
        node['text'] = get_text_of_pdf_pages_with_labels(
            pdf_pages, node.get('start_index'), node.get('end_index')
        )
        if 'nodes' in node:
            add_node_text_with_labels(node['nodes'], pdf_pages)
    return
|
|
|
|
|
|
async def generate_node_summary(node, model=None):
    """Ask the LLM for a short description of this node's text.

    Args:
        node: Dict with a 'text' field holding the section's raw text.
        model: Model name forwarded to ChatGPT_API_async.

    Returns:
        The model's description string (or "Error" if the API call failed).
    """
    prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.

Partial Document Text: {node['text']}

Directly return the description, do not include any other text.
"""
    response = await ChatGPT_API_async(model, prompt)
    return response
|
|
|
|
|
|
async def generate_summaries_for_structure(structure, model=None):
    """Summarize every node of the tree concurrently, storing each result under node['summary'].

    Returns the (mutated) structure.
    """
    nodes = structure_to_list(structure)
    summaries = await asyncio.gather(
        *(generate_node_summary(node, model=model) for node in nodes)
    )
    for node, summary in zip(nodes, summaries):
        node['summary'] = summary
    return structure
|
|
|
|
|
|
def generate_doc_description(structure, model=None):
    """Generate a one-sentence description of the whole document from its structure.

    Args:
        structure: The document tree; interpolated directly into the prompt.
        model: Model name forwarded to ChatGPT_API.

    Returns:
        The model's one-sentence description (or "Error" on API failure).
    """
    prompt = f"""Your are an expert in generating descriptions for a document.
You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.

Document Structure: {structure}

Directly return the description, do not include any other text.
"""
    response = ChatGPT_API(model, prompt)
    return response
|
|
|
|
|
|
class ConfigLoader:
    """Loads default options from a YAML file and overlays user-supplied options.

    User options with keys not present in the defaults are rejected, and the
    merged result is returned as a SimpleNamespace (aliased as `config`).
    """

    def __init__(self, default_path: str = None):
        # Default to the config.yaml sitting next to this module.
        if default_path is None:
            default_path = Path(__file__).parent / "config.yaml"
        self._default_dict = self._load_yaml(default_path)

    @staticmethod
    def _load_yaml(path):
        """Parse the YAML file at *path*; an empty file yields {}."""
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _validate_keys(self, user_dict):
        """Raise ValueError when user options contain keys the defaults don't define."""
        unknown_keys = set(user_dict) - set(self._default_dict)
        if unknown_keys:
            raise ValueError(f"Unknown config keys: {unknown_keys}")

    def load(self, user_opt=None) -> config:
        """
        Load the configuration, merging user options with default values.
        """
        if user_opt is None:
            user_dict = {}
        elif isinstance(user_opt, config):
            user_dict = vars(user_opt)
        elif isinstance(user_opt, dict):
            user_dict = user_opt
        else:
            raise TypeError("user_opt must be dict, config(SimpleNamespace) or None")

        self._validate_keys(user_dict)
        return config(**{**self._default_dict, **user_dict})