doc-to-lora/webui/self_gen_viewer.py
2026-02-27 03:47:04 +00:00

170 lines
5.2 KiB
Python

import traceback
from pathlib import Path
from datasets import load_dataset
from flask import Flask, jsonify, render_template, request
from transformers import AutoTokenizer
# Flask application object; the route decorators in this file register on it.
app = Flask(__name__)
# Root of the self_gen data tree: data/raw_datasets/self_gen under the parent
# of the directory that contains this file.
BASE_DATA_PATH = Path(__file__).parent.parent / "data" / "raw_datasets" / "self_gen"
# Maps a model path string to its loaded tokenizer so each tokenizer is loaded
# at most once per process (filled by get_tokenizer).
tokenizer_cache = {}
def get_tokenizer(model_path):
    """Return the tokenizer for ``model_path``, loading it on first use.

    Successfully loaded tokenizers are kept in ``tokenizer_cache``. When
    loading fails the error is printed and ``None`` is returned; the failure
    is not cached, so a later call tries again.
    """
    # Fast path: already loaded earlier in this process.
    if model_path in tokenizer_cache:
        return tokenizer_cache[model_path]

    try:
        loaded = AutoTokenizer.from_pretrained(model_path)
    except Exception as e:
        print(f"Error loading tokenizer for {model_path}: {e}")
        return None

    tokenizer_cache[model_path] = loaded
    return loaded
def discover_folders():
    """Discover all model folders in the self_gen directory.

    The expected layout is ``BASE_DATA_PATH/<vendor>/<model>``; only
    directories exactly two levels below ``BASE_DATA_PATH`` are reported.

    Returns:
        Sorted list of ``"<vendor>/<model>"`` strings relative to
        ``BASE_DATA_PATH``. Paths always use "/" as the separator because
        ``extract_model_name_from_folder`` splits these strings on "/".
        Empty when ``BASE_DATA_PATH`` is missing or is not a directory.
    """
    # is_dir() (rather than exists()) also covers the case where the base
    # path is a regular file, on which iterdir() would raise.
    if not BASE_DATA_PATH.is_dir():
        return []

    return sorted(
        model_dir.relative_to(BASE_DATA_PATH).as_posix()
        for vendor_dir in BASE_DATA_PATH.iterdir()
        if vendor_dir.is_dir()
        for model_dir in vendor_dir.iterdir()
        if model_dir.is_dir()
    )
def discover_parquet_files(folder_path):
    """Discover all parquet files in a folder, searching recursively.

    Args:
        folder_path: Folder path relative to ``BASE_DATA_PATH``.

    Returns:
        Sorted list of parquet file paths relative to the folder, using "/"
        as the separator. Empty when the folder does not exist, is not a
        directory, or would point outside ``BASE_DATA_PATH``.
    """
    requested = Path(folder_path)
    # api_parquet_files passes the raw "folder" query parameter here, so
    # refuse absolute/drive-qualified paths and ".." components: joining them
    # onto BASE_DATA_PATH would let a client list files elsewhere on disk.
    if requested.anchor or ".." in requested.parts:
        return []

    full_path = BASE_DATA_PATH / requested
    if not full_path.is_dir():
        return []

    return sorted(
        parquet_file.relative_to(full_path).as_posix()
        for parquet_file in full_path.glob("**/*.parquet")
    )
def extract_model_name_from_folder(folder_path):
    """Derive the base model name from a ``<vendor>/<model>...`` folder path.

    Everything from the first ``"_temp_"`` onward in the second path segment
    is dropped, e.g.
    ``"google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0"`` becomes
    ``"google/gemma-2-2b-it"``. Segments after the second are ignored.

    Returns ``None`` when the path has fewer than two "/"-separated segments.
    """
    segments = folder_path.split("/")
    if len(segments) < 2:
        return None

    vendor, model_dir = segments[0], segments[1]
    # Keep only the text before the generation-settings suffix.
    model_part, _, _ = model_dir.partition("_temp_")
    return f"{vendor}/{model_part}"
@app.route("/")
def index():
    """Render the viewer page, passing it the discovered model folders."""
    return render_template("self_gen_viewer.html", folders=discover_folders())
@app.route("/api/folders")
def api_folders():
    """Return the discovered model folders as JSON under the "folders" key."""
    payload = {"folders": discover_folders()}
    return jsonify(payload)
@app.route("/api/parquet_files")
def api_parquet_files():
    """Return the parquet files under the folder named by the "folder" query arg.

    Responds with ``{"files": [...]}``, or with a 400 and an "error" message
    when the "folder" argument is missing or empty.
    """
    requested = request.args.get("folder", "")
    if requested:
        return jsonify({"files": discover_parquet_files(requested)})
    return jsonify({"error": "No folder specified"}), 400
@app.route("/api/load_data")
def api_load_data():
    """API endpoint to load and decode samples from a parquet file.

    Query parameters:
        folder: Model folder relative to ``BASE_DATA_PATH``; also used to
            derive the tokenizer name via ``extract_model_name_from_folder``.
        file: Parquet file path relative to ``folder``.
        num_samples: Maximum number of rows to load (default 100). Must be a
            non-negative integer.

    Returns:
        On success, JSON with ``success``, ``num_samples``, ``model_name``,
        ``file_path`` and ``samples`` (one entry per row, see
        ``_decode_sample``). On failure, JSON with ``error`` and status 400
        (bad or missing parameters), 404 (file not found) or 500 (tokenizer
        could not be loaded, or an unexpected error, which also includes
        ``traceback``).
    """
    folder = request.args.get("folder", "")
    parquet_file = request.args.get("file", "")
    if not folder or not parquet_file:
        return jsonify({"error": "Missing parameters"}), 400

    # A malformed num_samples is a client error, not a server crash.
    try:
        num_samples = int(request.args.get("num_samples", 100))
    except ValueError:
        return jsonify({"error": "num_samples must be an integer"}), 400
    # A negative value would turn the split slice below into
    # "everything except the last N rows", which is never what is meant here.
    if num_samples < 0:
        return jsonify({"error": "num_samples must not be negative"}), 400

    # Both values come straight from the query string: refuse absolute or
    # drive-qualified paths and ".." components so the request cannot read
    # files outside BASE_DATA_PATH.
    relative_path = Path(folder) / parquet_file
    if relative_path.anchor or ".." in relative_path.parts:
        return jsonify({"error": "Invalid folder or file path"}), 400

    try:
        full_path = BASE_DATA_PATH / relative_path
        if not full_path.exists():
            return jsonify({"error": f"File not found: {full_path}"}), 404

        # The tokenizer is chosen from the folder name.
        model_name = extract_model_name_from_folder(folder)
        if not model_name:
            return jsonify({"error": "Could not extract model name from folder"}), 400

        tokenizer = get_tokenizer(model_name)
        if tokenizer is None:
            return jsonify({"error": f"Could not load tokenizer for {model_name}"}), 500

        # Only the first num_samples rows are requested from the file.
        ds = load_dataset(
            "parquet", data_files=str(full_path), split=f"train[:{num_samples}]"
        )

        samples = [
            _decode_sample(tokenizer, i, sample) for i, sample in enumerate(ds)
        ]

        return jsonify(
            {
                "success": True,
                "num_samples": len(samples),
                "model_name": model_name,
                "file_path": str(parquet_file),
                "samples": samples,
            }
        )
    except Exception as e:
        # Top-level boundary for this endpoint: log the failure and report it
        # to the client instead of letting the request crash.
        # NOTE(review): the traceback in the response exposes server
        # internals; acceptable for a local debugging tool only.
        traceback.print_exc()
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500


def _decode_sample(tokenizer, index, sample):
    """Decode one dataset row into the JSON-serializable dict sent to the UI.

    ``ctx_ids`` is decoded into ``ctx`` ("N/A" when absent or null).
    ``input_ids`` is decoded into ``questions``: one string per inner list
    when it is a list of token-id lists, a single string when it is a flat
    list of token ids, and an empty list when it is absent, null or empty.
    """
    ctx_ids = sample.get("ctx_ids")
    input_ids = sample.get("input_ids")

    if input_ids is None or len(input_ids) == 0:
        # Nothing to decode; indexing [0] below would raise on an empty row.
        questions = []
    elif isinstance(input_ids[0], (list, tuple)):
        # Multiple Q&A pairs, each its own token-id sequence.
        questions = [
            tokenizer.decode(qa, skip_special_tokens=False) for qa in input_ids
        ]
    else:
        # A single flat token-id sequence.
        questions = [tokenizer.decode(input_ids, skip_special_tokens=False)]

    return {
        "index": index,
        "ctx": "N/A"
        if ctx_ids is None
        else tokenizer.decode(ctx_ids, skip_special_tokens=False),
        "questions": questions,
    }
if __name__ == "__main__":
    import os

    print(f"Data path: {BASE_DATA_PATH}")
    print(f"Available folders: {discover_folders()}")

    # Flask's debug mode enables the Werkzeug interactive debugger, which can
    # execute arbitrary Python for anyone able to reach the port. It is
    # therefore opt-in (SELF_GEN_VIEWER_DEBUG=1) instead of always on.
    debug = os.environ.get("SELF_GEN_VIEWER_DEBUG", "").strip().lower() in (
        "1",
        "true",
        "yes",
    )
    # Without the debugger the server stays reachable on all interfaces as
    # before; with it, default to loopback unless a host is set explicitly
    # through SELF_GEN_VIEWER_HOST.
    host = os.environ.get(
        "SELF_GEN_VIEWER_HOST", "127.0.0.1" if debug else "0.0.0.0"
    )
    app.run(debug=debug, host=host, port=5001)