import os import sys from pathlib import Path import gradio as gr import torch # Add the src directory to the path sys.path.insert(0, str(Path(__file__).parent.parent)) from ctx_to_lora.data.processing import tokenize_ctx_text from ctx_to_lora.model_loading import get_tokenizer from ctx_to_lora.modeling import hypernet sys.modules["ctx_to_lora.modeling_utils"] = hypernet # Global state device = torch.device("cuda" if torch.cuda.is_available() else "cpu") modulated_model = None chat_history = [] ctx_tokenizer = None base_tokenizer = None try: DEFAULT_CONTEXT = Path("data/sakana_wiki.txt").read_text(encoding="utf-8").strip() except FileNotFoundError: DEFAULT_CONTEXT = "" WARNING_MESSAGE = ( "⚠️ **Caution**: This is an educational proof-of-concept demonstration.\n" "The model may generate inaccurate information or hallucinate facts." ) FOOTER = """ ⚠️ This model is an experimental prototype and is only available for educational and research and development purposes. It is not suitable for commercial use or in environments where failure can have significant effects (mission-critical environments). The use of this model is at the user's own risk and its performance and results is not guaranteed in any way. Sakana AI is not responsible for any direct or indirect loss resulting from using this model, regardless of the outcome. """ def load_custom_chat_template(tokenizer, model_name): if "gemma" in model_name.lower(): template_path = "chat_templates/google/gemma-2-2b-it.jinja" if os.path.exists(template_path): with open(template_path) as f: template_content = f.read() tokenizer.chat_template = template_content print(f"Loaded custom chat template from {template_path}") return True return False def get_available_checkpoints(): trained_d2l_checkpoints = { str(path) for path in Path().glob("trained_d2l/**/pytorch_model.bin") if path.is_file() } run_output_checkpoints = { str(path) for path in Path().glob("train_outputs/runs/**/pytorch_model.bin") if path.is_file() } checkpoints = sorted(trained_d2l_checkpoints) + sorted( run_output_checkpoints - trained_d2l_checkpoints ) return checkpoints if checkpoints else ["No checkpoints found"] def load_checkpoint( checkpoint_path: str, ) -> tuple[str, gr.update, gr.update, gr.update, gr.update]: global modulated_model, ctx_tokenizer, base_tokenizer, chat_history if not checkpoint_path or checkpoint_path == "No checkpoints found": return ( "⚠️ Please select a valid checkpoint", gr.update(), gr.update(), gr.update(), gr.update(), ) try: print(f"Loading checkpoint: {checkpoint_path}") from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel state_dict = torch.load(checkpoint_path, weights_only=False) modulated_model = ModulatedPretrainedModel.from_state_dict( state_dict, train=False, use_flash_attn=True, use_sequence_packing=False, ) modulated_model = modulated_model.to(device).to(torch.bfloat16) modulated_model.eval() ctx_encoder_model_name_or_path = ( modulated_model.ctx_encoder_args.ctx_encoder_model_name_or_path or modulated_model.base_model.config.name_or_path ) ctx_tokenizer = get_tokenizer(ctx_encoder_model_name_or_path, train=False) base_tokenizer = get_tokenizer( modulated_model.base_model.config.name_or_path, train=False ) load_custom_chat_template( base_tokenizer, modulated_model.base_model.config.name_or_path ) chat_history = [{"role": "system", "content": ""}] model_name = modulated_model.base_model.config.name_or_path success_msg = ( f"✅ Successfully loaded checkpoint!\n\nBase Model: {model_name}\n\n" "You can now add context and start chatting." ) return ( success_msg, gr.update(interactive=True), # msg gr.update(interactive=True), # send_btn gr.update(interactive=True), # system_msg gr.update(interactive=True), # clear_btn ) except Exception as e: import traceback error_msg = ( f"❌ Error loading checkpoint:\n{str(e)}\n\n{traceback.format_exc()}" ) print(error_msg) return ( error_msg, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), ) def process_context(context: str) -> dict: context = context.strip() if context else "" tokenized_contexts = tokenize_ctx_text({"context": [context]}, ctx_tokenizer) ctx_ids = tokenized_contexts["ctx_ids"] ctx_ids = [ torch.tensor(ctx_id, dtype=torch.long, device=device) for ctx_id in ctx_ids ] ctx_attn_mask = [torch.ones_like(ids) for ids in ctx_ids] ctx_attn_mask = [ torch.tensor(mask, dtype=torch.long, device=device) for mask in ctx_attn_mask ] ctx_ids = torch.nn.utils.rnn.pad_sequence( ctx_ids, batch_first=True, padding_value=0, ) ctx_attn_mask = torch.nn.utils.rnn.pad_sequence( ctx_attn_mask, batch_first=True, padding_value=0, ) return {"ctx_ids": ctx_ids, "ctx_attn_mask": ctx_attn_mask} def add_user_message(message: str, history): if not message.strip(): return history, "" return history + [[message, None]], "" def generate_response( history: list[list[str]], system_msg: str, context: str, context_scaler: float, bias_scaler: float, ): global modulated_model, chat_history, ctx_tokenizer, base_tokenizer if modulated_model is None: history[-1][1] = "Please load a checkpoint first." yield history return if not history or history[-1][0] is None: yield history return try: user_message = history[-1][0] if system_msg.strip() and chat_history[0]["role"] == "system": chat_history[0]["content"] = system_msg.strip() chat_history.append({"role": "user", "content": user_message}) context = context.strip() if context else "" print(f"Processing single context with scaler: {context_scaler}") print(f"Bias scaler: {bias_scaler}") with torch.inference_mode(), torch.amp.autocast(str(device)): ctx_inputs = process_context(context) ctx_ids = ctx_inputs["ctx_ids"].to(device) ctx_attn_mask = ctx_inputs["ctx_attn_mask"].to(device) scalers_tensor = torch.tensor( [context_scaler], dtype=torch.float32, device=device ) model_inputs = base_tokenizer.apply_chat_template( chat_history, return_tensors="pt", add_generation_prompt=True ).to(device) print(f"Context: {context}") print(f"Chat history: {chat_history}") outputs = modulated_model.generate( ctx_ids=ctx_ids, ctx_attn_mask=ctx_attn_mask, n_ctx_chunks=torch.tensor([len(ctx_ids)], device=ctx_ids.device), scalers=scalers_tensor, bias_scaler=bias_scaler, input_ids=model_inputs, max_new_tokens=512, do_sample=False, temperature=0, ) response = base_tokenizer.decode( outputs[0][model_inputs.shape[1] :], skip_special_tokens=True ) chat_history.append({"role": "assistant", "content": response}) words = response.split() partial_response = "" for word in words: partial_response += word + " " history[-1][1] = partial_response.strip() yield history history[-1][1] = response yield history except Exception as e: import traceback error_msg = f"❌ Error: {str(e)}" print(f"Error generating response: {str(e)}\n\n{traceback.format_exc()}") history[-1][1] = error_msg yield history def reset_chat(system_msg: str): global chat_history chat_history = [ {"role": "system", "content": system_msg.strip() if system_msg else ""} ] return [[None, WARNING_MESSAGE]], "Chat history reset successfully!" custom_css = """ :root { color-scheme: light; } .gradio-container { font-family: 'Inter', sans-serif; } .chat-container { border-radius: 10px; border: 2px solid #d1d5db; } .context-field { border-radius: 8px; padding: 15px; margin-bottom: 10px; border: 2px solid #d1d5db; } .status-box { border-radius: 8px; padding: 15px; margin: 10px 0; border: 2px solid #e5e7eb; } .primary-button { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border: none; border-radius: 6px; color: white; font-weight: 600; } .secondary-button { background-color: #607d8b; border: none; border-radius: 6px; color: white; } #chatbot { height: 500px; } .instruction-text { font-style: italic; color: #666; font-size: 0.9em; margin-bottom: 10px; } .warning-box { background-color: #fff3cd; border: 2px solid #ffc107; border-radius: 8px; padding: 12px; margin: 10px 0; color: #856404; font-size: 0.95em; } .warning-box strong { color: #d97706; } .disabled-overlay { background-color: #f5f5f5; border: 3px dashed #999; border-radius: 10px; padding: 20px; text-align: center; color: #999; } .chat-disabled-notice { border: 3px solid #f59e0b; border-radius: 8px; padding: 15px; margin-bottom: 15px; color: #92400e; font-weight: 600; text-align: center; background-color: #fef3c7; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); } .internalization-banner { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; padding: 20px; margin-bottom: 15px; text-align: center; font-weight: 600; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); border: 3px solid #5568d3; } .internalization-banner h3 { margin: 0 0 10px 0; font-size: 1.2em; color: white; } .internalization-banner p { margin: 5px 0; font-size: 0.95em; font-weight: 400; color: white; } .context-section-header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border: 3px solid #5568d3; border-left: 6px solid #4c51bf; padding: 15px; margin-bottom: 15px; border-radius: 6px; color: white; box-shadow: 0 2px 6px rgba(102, 126, 234, 0.3); } .context-section-header strong, .context-section-header small { color: white; } .chat-section-header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border: 3px solid #2563eb; border-left: 6px solid #1d4ed8; padding: 15px; margin-bottom: 15px; border-radius: 6px; color: white; box-shadow: 0 2px 6px rgba(59, 130, 246, 0.3); } .chat-section-header strong, .chat-section-header small { color: white; } .panel-box { background-color: rgba(249, 250, 251, 0.5); border: 2px solid rgba(209, 213, 219, 0.5); border-radius: 12px; padding: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05); } .chat-panel-box { background-color: rgba(249, 250, 251, 0.5); border: 2px solid rgba(209, 213, 219, 0.5); border-radius: 12px; padding: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05); } #checkpoint-dropdown { position: relative; z-index: 20; } #checkpoint-dropdown [role="listbox"] { z-index: 9999 !important; } /* Dark mode support */ .dark .panel-box, .dark .chat-panel-box { background-color: rgba(31, 41, 55, 0.5); border: 2px solid rgba(75, 85, 99, 0.5); } .dark .context-field, .dark .chat-container { border-color: rgba(75, 85, 99, 0.5); } .dark .status-box { border-color: rgba(55, 65, 81, 0.5); } """ def create_demo(): with gr.Blocks( title="Doc-to-LoRA Chat Interface", theme=gr.themes.Soft(), css=custom_css, ) as demo: gr.Markdown( """ # 📜 Doc-to-LoRA Chat Interface Load a hypernetwork checkpoint and chat with a context-modulated language model. Add one context with a scaling parameter to influence the model's responses. """ ) gr.HTML( """
""" ) gr.Markdown( """ ### 📖 Usage Instructions 1. **Load a Checkpoint**: Select a hypernetwork checkpoint from the dropdown and click "Load Checkpoint" 2. **Configure Context**: - Enter your context information in the text field - Adjust the scaling slider to control context influence 3. **Set Bias Scaler**: Adjust the bias scaler to control overall model behavior 4. **Start Chatting**: Once the model is loaded, type your message and press Shift+Enter or click Send 5. **Reset**: Use the "Reset Chat" button to start a new conversation 💡 **Tip**: You can use context to provide background information or specific knowledge that should influence the model's responses. """ ) with gr.Row(): with gr.Column(scale=1, elem_classes="panel-box"): gr.Markdown("### 📦 Load Checkpoint") gr.Markdown( "*Select a trained hypernetwork checkpoint to begin.*", elem_classes="instruction-text", ) checkpoint_dropdown = gr.Dropdown( choices=get_available_checkpoints(), label="Select Checkpoint", value=None, interactive=True, elem_id="checkpoint-dropdown", ) load_btn = gr.Button("Load Checkpoint", variant="primary", size="lg") status_box = gr.Textbox( label="Status", lines=8, interactive=False, elem_classes="status-box", ) gr.Markdown("---") gr.HTML( """