"""Shared helpers for LLM calls, JSON extraction, document trees and page lookup.

The module groups four families of utilities:

* thin retrying wrappers around ``litellm.completion`` / ``litellm.acompletion``;
* best-effort extraction of JSON from model output;
* helpers that walk, flatten, annotate and clean nested ``{'nodes': [...]}``
  document structures;
* page-selection helpers for PDF and Markdown documents.
"""
import litellm
import logging
import time
import json
import copy
import re
import asyncio
import PyPDF2

logger = logging.getLogger(__name__)

# Number of attempts made by llm_completion / llm_acompletion before giving up.
_MAX_LLM_RETRIES = 10

# Matches a comma that is followed (optionally after whitespace) by a closing
# bracket or brace, i.e. a trailing comma that strict JSON does not allow.
_TRAILING_COMMA_RE = re.compile(r",\s*([\]}])")


def count_tokens(text, model=None):
    """Return the token count of *text* as reported by ``litellm.token_counter``.

    Falsy input (``None`` or ``""``) yields 0 without calling litellm.
    """
    if not text:
        return 0
    return litellm.token_counter(model=model, text=text)


def _normalize_model(model):
    """Strip an optional leading ``litellm/`` from a model name."""
    return model.removeprefix("litellm/") if model else model


def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
    """Run a synchronous chat completion with retries.

    Args:
        model: Model name; a leading ``litellm/`` prefix is removed.
        prompt: Text sent as the final user message.
        chat_history: Optional iterable of prior message dicts that is placed
            before the prompt. It is copied, never mutated.
        return_finish_reason: When true, return ``(content, finish_reason)``
            where ``finish_reason`` is ``"max_output_reached"`` if the provider
            reported ``"length"`` and ``"finished"`` otherwise.

    Returns:
        The message content of the first choice, or the tuple described above.

    Raises:
        RuntimeError: If every attempt fails; the last exception is chained.
    """
    model = _normalize_model(model)
    messages = [*(chat_history or []), {"role": "user", "content": prompt}]
    litellm.drop_params = True

    for attempt in range(1, _MAX_LLM_RETRIES + 1):
        try:
            response = litellm.completion(
                model=model,
                messages=messages,
                temperature=0,
            )
            choice = response.choices[0]
            content = choice.message.content
            if return_finish_reason:
                if choice.finish_reason == "length":
                    return content, "max_output_reached"
                return content, "finished"
            return content
        except Exception as e:
            logger.warning("Retrying LLM completion (%d/%d)", attempt, _MAX_LLM_RETRIES)
            logger.error("Error: %s", e)
            if attempt < _MAX_LLM_RETRIES:
                time.sleep(1)
            else:
                logger.error("Max retries reached for prompt: %s", prompt)
                raise RuntimeError(
                    f"LLM call failed after {_MAX_LLM_RETRIES} retries"
                ) from e


async def llm_acompletion(model, prompt):
    """Run an asynchronous single-message chat completion with retries.

    Args:
        model: Model name; a leading ``litellm/`` prefix is removed.
        prompt: Text sent as the only user message.

    Returns:
        The message content of the first choice.

    Raises:
        RuntimeError: If every attempt fails; the last exception is chained.
    """
    model = _normalize_model(model)
    messages = [{"role": "user", "content": prompt}]
    litellm.drop_params = True

    for attempt in range(1, _MAX_LLM_RETRIES + 1):
        try:
            response = await litellm.acompletion(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.warning(
                "Retrying async LLM completion (%d/%d)", attempt, _MAX_LLM_RETRIES
            )
            logger.error("Error: %s", e)
            if attempt < _MAX_LLM_RETRIES:
                await asyncio.sleep(1)
            else:
                logger.error("Max retries reached for prompt: %s", prompt)
                raise RuntimeError(
                    f"LLM call failed after {_MAX_LLM_RETRIES} retries"
                ) from e


def extract_json(content):
    """Parse JSON out of *content*, returning ``{}`` when nothing can be parsed.

    If *content* contains a ```` ```json ```` fence, only the text between that
    fence and the last closing fence after it is parsed (or everything after
    the opening fence when no closing fence exists). Otherwise the whole string
    is parsed.

    Before parsing, every occurrence of ``None`` is replaced by ``null`` and all
    whitespace runs are collapsed to single spaces.
    NOTE: both clean-ups are purely textual, so they also apply inside string
    values.

    If parsing fails, trailing commas before ``]`` / ``}`` are removed and
    parsing is retried once. Any other failure (including non-string input) is
    logged and results in ``{}``.
    """
    try:
        fence_start = content.find("```json")
        if fence_start != -1:
            # Look for the closing fence only *after* the opening one so that a
            # missing closing fence cannot match the opening fence itself.
            body = content[fence_start + len("```json"):]
            fence_end = body.rfind("```")
            if fence_end != -1:
                body = body[:fence_end]
        else:
            body = content

        json_content = body.strip().replace('None', 'null')
        # str.split() with no argument splits on every whitespace run, which
        # also removes newlines and carriage returns.
        json_content = ' '.join(json_content.split())
        return json.loads(json_content)
    except json.JSONDecodeError as e:
        logger.error("Failed to extract JSON: %s", e)
        try:
            repaired = _TRAILING_COMMA_RE.sub(r"\1", json_content)
            return json.loads(repaired)
        except Exception:
            logger.error("Failed to parse JSON even after cleanup")
            return {}
    except Exception as e:
        logger.error("Unexpected error while extracting JSON: %s", e)
        return {}


def get_json_content(response):
    """Return the text of *response* with a surrounding ```` ```json ```` fence removed.

    Text up to and including the first ```` ```json ```` marker is dropped (when
    present), then everything from the last remaining ```` ``` ```` onwards is
    dropped (when present). The result is stripped of surrounding whitespace
    and is not parsed.
    """
    start_idx = response.find("```json")
    if start_idx != -1:
        response = response[start_idx + 7:]

    end_idx = response.rfind("```")
    if end_idx != -1:
        response = response[:end_idx]

    return response.strip()


def write_node_id(data, node_id=0):
    """Assign sequential zero-padded ``node_id`` values in depth-first order.

    Each dict receives ``node_id`` as a 4-digit string, then every value whose
    key contains the substring ``'nodes'`` is visited recursively. Lists are
    visited element by element. The structure is modified in place.

    Returns:
        The next unused integer id.
    """
    if isinstance(data, dict):
        data['node_id'] = str(node_id).zfill(4)
        node_id += 1
        for key in list(data.keys()):
            if 'nodes' in key:
                node_id = write_node_id(data[key], node_id)
    elif isinstance(data, list):
        for item in data:
            node_id = write_node_id(item, node_id)
    return node_id


def remove_fields(data, fields=None):
    """Return a copy of *data* with the given keys removed at every depth.

    Args:
        data: Arbitrarily nested dicts/lists; other values are returned as is.
        fields: Keys to drop. ``None`` or an empty value means ``["text"]``.
    """
    fields = fields or ["text"]
    if isinstance(data, dict):
        return {k: remove_fields(v, fields) for k, v in data.items() if k not in fields}
    elif isinstance(data, list):
        return [remove_fields(item, fields) for item in data]
    return data


def structure_to_list(structure):
    """Flatten a tree into a pre-order list of the *original* node dicts.

    Only the ``'nodes'`` key is followed. The returned dicts are the same
    objects as in *structure*, so mutating them mutates the tree. Input that is
    neither a dict nor a list yields an empty list.
    """
    if isinstance(structure, dict):
        nodes = [structure]
        if 'nodes' in structure:
            nodes.extend(structure_to_list(structure['nodes']))
        return nodes
    if isinstance(structure, list):
        nodes = []
        for item in structure:
            nodes.extend(structure_to_list(item))
        return nodes
    return []


def _copy_without_children(node):
    """Deep-copy a node dict while leaving out its ``'nodes'`` entry."""
    return copy.deepcopy({k: v for k, v in node.items() if k != 'nodes'})


def get_nodes(structure):
    """Return deep copies of every node in pre-order, without their ``'nodes'``.

    Children are discovered through every key containing the substring
    ``'nodes'``. Input that is neither a dict nor a list yields an empty list.
    """
    if isinstance(structure, dict):
        nodes = [_copy_without_children(structure)]
        for key in list(structure.keys()):
            if 'nodes' in key:
                nodes.extend(get_nodes(structure[key]))
        return nodes
    if isinstance(structure, list):
        nodes = []
        for item in structure:
            nodes.extend(get_nodes(item))
        return nodes
    return []


def get_leaf_nodes(structure):
    """Return deep copies of all leaf nodes, without their ``'nodes'`` entry.

    A dict is a leaf when its ``'nodes'`` value is missing or empty; trees
    produced by ``list_to_tree`` drop the key on leaves, so a missing key must
    not raise. Input that is neither a dict nor a list yields an empty list.
    """
    if isinstance(structure, dict):
        if not structure.get('nodes'):
            return [_copy_without_children(structure)]
        leaf_nodes = []
        for key in list(structure.keys()):
            if 'nodes' in key:
                leaf_nodes.extend(get_leaf_nodes(structure[key]))
        return leaf_nodes
    if isinstance(structure, list):
        leaf_nodes = []
        for item in structure:
            leaf_nodes.extend(get_leaf_nodes(item))
        return leaf_nodes
    return []


async def generate_node_summary(node, model=None):
    """Ask the LLM for a description of ``node['text']`` and return the reply.

    Raises:
        KeyError: If *node* has no ``'text'`` entry.
    """
    # Prompt wording is kept verbatim; adjacent literals concatenate to it.
    prompt = (
        "You are given a part of a document, your task is to generate a description "
        "of the partial document about what are main points covered in the partial document. "
        f"Partial Document Text: {node['text']} "
        "Directly return the description, do not include any other text. "
    )
    response = await llm_acompletion(model, prompt)
    return response


async def generate_summaries_for_structure(structure, model=None):
    """Generate a summary for every node concurrently and store it in place.

    All nodes returned by ``structure_to_list`` are summarised with
    ``asyncio.gather`` and each receives a ``'summary'`` key.

    Returns:
        The same *structure* object, now annotated.
    """
    nodes = structure_to_list(structure)
    tasks = [generate_node_summary(node, model=model) for node in nodes]
    summaries = await asyncio.gather(*tasks)

    for node, summary in zip(nodes, summaries):
        node['summary'] = summary
    return structure


def generate_doc_description(structure, model=None):
    """Ask the LLM for a one-sentence description of a document structure.

    *structure* is interpolated into the prompt with ``str()`` formatting.
    """
    # Prompt wording (including its single line break) is kept verbatim.
    prompt = (
        "Your are an expert in generating descriptions for a document. "
        "You are given a structure of a document. \n"
        "Your task is to generate a one-sentence description for the document, "
        "which makes it easy to distinguish the document from other documents. "
        f"Document Structure: {structure} "
        "Directly return the description, do not include any other text. "
    )
    response = llm_completion(model, prompt)
    return response


def list_to_tree(data):
    """Build a nested tree from a flat list of items with dotted structure codes.

    Each item's ``'structure'`` value (for example ``"1.2"``) determines its
    parent (``"1"``). An item is attached to its parent only if the parent
    appeared earlier in *data*; otherwise it becomes a root. Output nodes carry
    ``title``, ``start_index``, ``end_index`` and, when non-empty, ``nodes``.

    Returns:
        The list of root nodes.
    """
    def get_parent_structure(structure):
        """Return the parent code of a dotted structure code, or None."""
        if not structure:
            return None
        parts = str(structure).split('.')
        return '.'.join(parts[:-1]) if len(parts) > 1 else None

    # First pass: create nodes and link each one to its parent.
    nodes = {}
    root_nodes = []

    for item in data:
        structure = item.get('structure')
        node = {
            'title': item.get('title'),
            'start_index': item.get('start_index'),
            'end_index': item.get('end_index'),
            'nodes': [],
        }
        nodes[structure] = node

        parent_structure = get_parent_structure(structure)
        if parent_structure and parent_structure in nodes:
            nodes[parent_structure]['nodes'].append(node)
        else:
            # No parent code, or the parent has not been seen: treat as root.
            root_nodes.append(node)

    def clean_node(node):
        """Recursively drop empty ``'nodes'`` lists."""
        if not node['nodes']:
            del node['nodes']
        else:
            for child in node['nodes']:
                clean_node(child)
        return node

    return [clean_node(node) for node in root_nodes]


def post_processing(structure, end_physical_index):
    """Derive page ranges for a flat list of items and convert it to a tree.

    For every item, ``start_index`` is its ``physical_index``. ``end_index`` is
    the next item's ``physical_index``, minus one when the next item has
    ``appear_start == 'yes'``; the last item ends at *end_physical_index*.
    Items are modified in place.

    Returns:
        The tree from ``list_to_tree`` when it is non-empty; otherwise the flat
        list itself with ``appear_start`` and ``physical_index`` removed.
    """
    last_position = len(structure) - 1
    for i, item in enumerate(structure):
        item['start_index'] = item.get('physical_index')
        if i < last_position:
            following = structure[i + 1]
            if following.get('appear_start') == 'yes':
                item['end_index'] = following['physical_index'] - 1
            else:
                item['end_index'] = following['physical_index']
        else:
            item['end_index'] = end_physical_index

    tree = list_to_tree(structure)
    if tree:
        return tree

    # No tree could be built: return the flat list without helper fields.
    for node in structure:
        node.pop('appear_start', None)
        node.pop('physical_index', None)
    return structure


def reorder_dict(data, key_order):
    """Return a dict with only the keys of *key_order* present in *data*, in that order.

    A falsy *key_order* returns *data* unchanged (same object).
    """
    if not key_order:
        return data
    return {key: data[key] for key in key_order if key in data}


def format_structure(structure, order=None):
    """Recursively reorder node keys according to *order* and drop empty ``nodes``.

    A falsy *order* returns *structure* unchanged. Dict inputs have their
    ``'nodes'`` entry updated in place before a reordered copy is returned;
    keys not listed in *order* are omitted from that copy.
    """
    if not order:
        return structure
    if isinstance(structure, dict):
        if 'nodes' in structure:
            structure['nodes'] = format_structure(structure['nodes'], order)
        if not structure.get('nodes'):
            structure.pop('nodes', None)
        structure = reorder_dict(structure, order)
    elif isinstance(structure, list):
        structure = [format_structure(item, order) for item in structure]
    return structure


def create_clean_structure_for_description(structure):
    """
    Create a clean structure for document description generation,
    excluding unnecessary fields like 'text'.

    Only ``title``, ``node_id``, ``summary`` and ``prefix_summary`` are kept on
    each node, plus non-empty ``nodes`` processed recursively.
    """
    if isinstance(structure, dict):
        clean_node = {}
        for key in ['title', 'node_id', 'summary', 'prefix_summary']:
            if key in structure:
                clean_node[key] = structure[key]

        if 'nodes' in structure and structure['nodes']:
            clean_node['nodes'] = create_clean_structure_for_description(structure['nodes'])

        return clean_node
    elif isinstance(structure, list):
        return [create_clean_structure_for_description(item) for item in structure]
    else:
        return structure


def _get_text_of_pages(page_list, start_page, end_page):
    """Concatenate text from page_list for pages [start_page, end_page] (1-indexed)."""
    # Index explicitly (rather than slice) so an out-of-range page still raises.
    return "".join(
        page_list[page_num][0] for page_num in range(start_page - 1, end_page)
    )


def add_node_text(node, page_list):
    """Recursively add 'text' field to each node from page_list content.

    Each node must have 'start_index' and 'end_index' (1-indexed page numbers).
    page_list is [(page_text, token_count), ...].

    Nodes where either index is ``None`` are left without ``'text'``; their
    children are still processed.
    """
    if isinstance(node, dict):
        start_page = node.get('start_index')
        end_page = node.get('end_index')
        if start_page is not None and end_page is not None:
            node['text'] = _get_text_of_pages(page_list, start_page, end_page)
        if 'nodes' in node:
            add_node_text(node['nodes'], page_list)
    elif isinstance(node, list):
        for item in node:
            add_node_text(item, page_list)


def remove_structure_text(data):
    """Remove the ``'text'`` key from every node in place and return *data*."""
    if isinstance(data, dict):
        data.pop('text', None)
        if 'nodes' in data:
            remove_structure_text(data['nodes'])
    elif isinstance(data, list):
        for item in data:
            remove_structure_text(item)
    return data


# ── Functions migrated from retrieve.py ──────────────────────────────────────

def parse_pages(pages: str) -> list[int]:
    """Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints.

    Values below 1 are dropped and duplicates are removed.

    Raises:
        ValueError: If a part is not an integer or range, if a range has
            ``start > end``, or if more than 1000 pages are selected.
    """
    max_pages = 1000
    collected = []
    for part in pages.split(','):
        part = part.strip()
        if '-' in part:
            start_text, _, end_text = part.partition('-')
            start, end = int(start_text.strip()), int(end_text.strip())
            if start > end:
                raise ValueError(f"Invalid range '{part}': start must be <= end")
            # Reject oversized ranges before expanding them so a huge range
            # cannot allocate an enormous list first.
            span = end - max(start, 1) + 1
            if span > max_pages:
                raise ValueError(f"Page range too large: {span} pages (max {max_pages})")
            collected.extend(range(start, end + 1))
        else:
            collected.append(int(part))

    result = sorted({p for p in collected if p >= 1})
    if len(result) > max_pages:
        raise ValueError(f"Page range too large: {len(result)} pages (max {max_pages})")
    return result


def get_pdf_page_content(file_path: str, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once.

    Page numbers outside ``1..total_pages`` are ignored. Pages for which
    ``extract_text()`` returns a falsy value get an empty string.

    Returns:
        ``[{'page': n, 'content': text}, ...]`` in the order of *page_nums*.
    """
    with open(file_path, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        total = len(pdf_reader.pages)
        valid_pages = [p for p in page_nums if 1 <= p <= total]
        return [
            {'page': p, 'content': pdf_reader.pages[p - 1].extract_text() or ''}
            for p in valid_pages
        ]


def get_md_page_content(structure: list, page_nums: list[int]) -> list[dict]:
    """
    For Markdown documents, 'pages' are line numbers.
    Find nodes whose line_num falls within [min(page_nums), max(page_nums)] and return their text.

    Each line number is reported at most once, and nodes with a falsy
    ``line_num`` are skipped. Results are sorted by line number.
    """
    if not page_nums:
        return []
    min_line, max_line = min(page_nums), max(page_nums)
    results = []
    seen = set()

    def _traverse(nodes):
        for node in nodes:
            ln = node.get('line_num')
            if ln and min_line <= ln <= max_line and ln not in seen:
                seen.add(ln)
                results.append({'page': ln, 'content': node.get('text', '')})
            if node.get('nodes'):
                _traverse(node['nodes'])

    _traverse(structure)
    results.sort(key=lambda x: x['page'])
    return results