doc-to-lora/demo/app.py
2026-02-27 03:47:04 +00:00

685 lines
20 KiB
Python

import os
import sys
from pathlib import Path
import gradio as gr
import torch
# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent.parent))
from ctx_to_lora.data.processing import tokenize_ctx_text
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling import hypernet
sys.modules["ctx_to_lora.modeling_utils"] = hypernet
# Global state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modulated_model = None
chat_history = []
ctx_tokenizer = None
base_tokenizer = None
try:
DEFAULT_CONTEXT = Path("data/sakana_wiki.txt").read_text(encoding="utf-8").strip()
except FileNotFoundError:
DEFAULT_CONTEXT = ""
WARNING_MESSAGE = (
"⚠️ **Caution**: This is an educational proof-of-concept demonstration.\n"
"The model may generate inaccurate information or hallucinate facts."
)
FOOTER = """
⚠️ This model is an experimental prototype and is only available for educational and research and development purposes. It is not suitable for commercial use or in environments where failure can have significant effects (mission-critical environments).
The use of this model is at the user's own risk and its performance and results is not guaranteed in any way.
Sakana AI is not responsible for any direct or indirect loss resulting from using this model, regardless of the outcome.
"""
def load_custom_chat_template(tokenizer, model_name):
if "gemma" in model_name.lower():
template_path = "chat_templates/google/gemma-2-2b-it.jinja"
if os.path.exists(template_path):
with open(template_path) as f:
template_content = f.read()
tokenizer.chat_template = template_content
print(f"Loaded custom chat template from {template_path}")
return True
return False
def get_available_checkpoints():
trained_d2l_checkpoints = {
str(path)
for path in Path().glob("trained_d2l/**/pytorch_model.bin")
if path.is_file()
}
run_output_checkpoints = {
str(path)
for path in Path().glob("train_outputs/runs/**/pytorch_model.bin")
if path.is_file()
}
checkpoints = sorted(trained_d2l_checkpoints) + sorted(
run_output_checkpoints - trained_d2l_checkpoints
)
return checkpoints if checkpoints else ["No checkpoints found"]
def load_checkpoint(
checkpoint_path: str,
) -> tuple[str, gr.update, gr.update, gr.update, gr.update]:
global modulated_model, ctx_tokenizer, base_tokenizer, chat_history
if not checkpoint_path or checkpoint_path == "No checkpoints found":
return (
"⚠️ Please select a valid checkpoint",
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
print(f"Loading checkpoint: {checkpoint_path}")
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
state_dict = torch.load(checkpoint_path, weights_only=False)
modulated_model = ModulatedPretrainedModel.from_state_dict(
state_dict,
train=False,
use_flash_attn=True,
use_sequence_packing=False,
)
modulated_model = modulated_model.to(device).to(torch.bfloat16)
modulated_model.eval()
ctx_encoder_model_name_or_path = (
modulated_model.ctx_encoder_args.ctx_encoder_model_name_or_path
or modulated_model.base_model.config.name_or_path
)
ctx_tokenizer = get_tokenizer(ctx_encoder_model_name_or_path, train=False)
base_tokenizer = get_tokenizer(
modulated_model.base_model.config.name_or_path, train=False
)
load_custom_chat_template(
base_tokenizer, modulated_model.base_model.config.name_or_path
)
chat_history = [{"role": "system", "content": ""}]
model_name = modulated_model.base_model.config.name_or_path
success_msg = (
f"✅ Successfully loaded checkpoint!\n\nBase Model: {model_name}\n\n"
"You can now add context and start chatting."
)
return (
success_msg,
gr.update(interactive=True), # msg
gr.update(interactive=True), # send_btn
gr.update(interactive=True), # system_msg
gr.update(interactive=True), # clear_btn
)
except Exception as e:
import traceback
error_msg = (
f"❌ Error loading checkpoint:\n{str(e)}\n\n{traceback.format_exc()}"
)
print(error_msg)
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
)
def process_context(context: str) -> dict:
context = context.strip() if context else ""
tokenized_contexts = tokenize_ctx_text({"context": [context]}, ctx_tokenizer)
ctx_ids = tokenized_contexts["ctx_ids"]
ctx_ids = [
torch.tensor(ctx_id, dtype=torch.long, device=device) for ctx_id in ctx_ids
]
ctx_attn_mask = [torch.ones_like(ids) for ids in ctx_ids]
ctx_attn_mask = [
torch.tensor(mask, dtype=torch.long, device=device) for mask in ctx_attn_mask
]
ctx_ids = torch.nn.utils.rnn.pad_sequence(
ctx_ids,
batch_first=True,
padding_value=0,
)
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
ctx_attn_mask,
batch_first=True,
padding_value=0,
)
return {"ctx_ids": ctx_ids, "ctx_attn_mask": ctx_attn_mask}
def add_user_message(message: str, history):
if not message.strip():
return history, ""
return history + [[message, None]], ""
def generate_response(
history: list[list[str]],
system_msg: str,
context: str,
context_scaler: float,
bias_scaler: float,
):
global modulated_model, chat_history, ctx_tokenizer, base_tokenizer
if modulated_model is None:
history[-1][1] = "Please load a checkpoint first."
yield history
return
if not history or history[-1][0] is None:
yield history
return
try:
user_message = history[-1][0]
if system_msg.strip() and chat_history[0]["role"] == "system":
chat_history[0]["content"] = system_msg.strip()
chat_history.append({"role": "user", "content": user_message})
context = context.strip() if context else ""
print(f"Processing single context with scaler: {context_scaler}")
print(f"Bias scaler: {bias_scaler}")
with torch.inference_mode(), torch.amp.autocast(str(device)):
ctx_inputs = process_context(context)
ctx_ids = ctx_inputs["ctx_ids"].to(device)
ctx_attn_mask = ctx_inputs["ctx_attn_mask"].to(device)
scalers_tensor = torch.tensor(
[context_scaler], dtype=torch.float32, device=device
)
model_inputs = base_tokenizer.apply_chat_template(
chat_history, return_tensors="pt", add_generation_prompt=True
).to(device)
print(f"Context: {context}")
print(f"Chat history: {chat_history}")
outputs = modulated_model.generate(
ctx_ids=ctx_ids,
ctx_attn_mask=ctx_attn_mask,
n_ctx_chunks=torch.tensor([len(ctx_ids)], device=ctx_ids.device),
scalers=scalers_tensor,
bias_scaler=bias_scaler,
input_ids=model_inputs,
max_new_tokens=512,
do_sample=False,
temperature=0,
)
response = base_tokenizer.decode(
outputs[0][model_inputs.shape[1] :], skip_special_tokens=True
)
chat_history.append({"role": "assistant", "content": response})
words = response.split()
partial_response = ""
for word in words:
partial_response += word + " "
history[-1][1] = partial_response.strip()
yield history
history[-1][1] = response
yield history
except Exception as e:
import traceback
error_msg = f"❌ Error: {str(e)}"
print(f"Error generating response: {str(e)}\n\n{traceback.format_exc()}")
history[-1][1] = error_msg
yield history
def reset_chat(system_msg: str):
global chat_history
chat_history = [
{"role": "system", "content": system_msg.strip() if system_msg else ""}
]
return [[None, WARNING_MESSAGE]], "Chat history reset successfully!"
custom_css = """
:root {
color-scheme: light;
}
.gradio-container {
font-family: 'Inter', sans-serif;
}
.chat-container {
border-radius: 10px;
border: 2px solid #d1d5db;
}
.context-field {
border-radius: 8px;
padding: 15px;
margin-bottom: 10px;
border: 2px solid #d1d5db;
}
.status-box {
border-radius: 8px;
padding: 15px;
margin: 10px 0;
border: 2px solid #e5e7eb;
}
.primary-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
border-radius: 6px;
color: white;
font-weight: 600;
}
.secondary-button {
background-color: #607d8b;
border: none;
border-radius: 6px;
color: white;
}
#chatbot {
height: 500px;
}
.instruction-text {
font-style: italic;
color: #666;
font-size: 0.9em;
margin-bottom: 10px;
}
.warning-box {
background-color: #fff3cd;
border: 2px solid #ffc107;
border-radius: 8px;
padding: 12px;
margin: 10px 0;
color: #856404;
font-size: 0.95em;
}
.warning-box strong {
color: #d97706;
}
.disabled-overlay {
background-color: #f5f5f5;
border: 3px dashed #999;
border-radius: 10px;
padding: 20px;
text-align: center;
color: #999;
}
.chat-disabled-notice {
border: 3px solid #f59e0b;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
color: #92400e;
font-weight: 600;
text-align: center;
background-color: #fef3c7;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.internalization-banner {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
padding: 20px;
margin-bottom: 15px;
text-align: center;
font-weight: 600;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
border: 3px solid #5568d3;
}
.internalization-banner h3 {
margin: 0 0 10px 0;
font-size: 1.2em;
color: white;
}
.internalization-banner p {
margin: 5px 0;
font-size: 0.95em;
font-weight: 400;
color: white;
}
.context-section-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: 3px solid #5568d3;
border-left: 6px solid #4c51bf;
padding: 15px;
margin-bottom: 15px;
border-radius: 6px;
color: white;
box-shadow: 0 2px 6px rgba(102, 126, 234, 0.3);
}
.context-section-header strong,
.context-section-header small {
color: white;
}
.chat-section-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: 3px solid #2563eb;
border-left: 6px solid #1d4ed8;
padding: 15px;
margin-bottom: 15px;
border-radius: 6px;
color: white;
box-shadow: 0 2px 6px rgba(59, 130, 246, 0.3);
}
.chat-section-header strong,
.chat-section-header small {
color: white;
}
.panel-box {
background-color: rgba(249, 250, 251, 0.5);
border: 2px solid rgba(209, 213, 219, 0.5);
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}
.chat-panel-box {
background-color: rgba(249, 250, 251, 0.5);
border: 2px solid rgba(209, 213, 219, 0.5);
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}
#checkpoint-dropdown {
position: relative;
z-index: 20;
}
#checkpoint-dropdown [role="listbox"] {
z-index: 9999 !important;
}
/* Dark mode support */
.dark .panel-box,
.dark .chat-panel-box {
background-color: rgba(31, 41, 55, 0.5);
border: 2px solid rgba(75, 85, 99, 0.5);
}
.dark .context-field,
.dark .chat-container {
border-color: rgba(75, 85, 99, 0.5);
}
.dark .status-box {
border-color: rgba(55, 65, 81, 0.5);
}
"""
def create_demo():
with gr.Blocks(
title="Doc-to-LoRA Chat Interface",
theme=gr.themes.Soft(),
css=custom_css,
) as demo:
gr.Markdown(
"""
# 📜 Doc-to-LoRA Chat Interface
Load a hypernetwork checkpoint and chat with a context-modulated language model.
Add one context with a scaling parameter to influence the model's responses.
"""
)
gr.HTML(
"""
<div class="internalization-banner">
<h3>🧠 How Context Internalization Works</h3>
<p>📥 Contexts are processed by the hypernetworkto dynamically modulate the base model's parameters</p>
<p>🚫 Contexts are NOT passed as text to the base model — they influence behavior internally</p>
<p>💬 Only your chat messages (below) are sent to the language model</p>
</div>
"""
)
gr.Markdown(
"""
### 📖 Usage Instructions
1. **Load a Checkpoint**: Select a hypernetwork checkpoint from the dropdown and click "Load Checkpoint"
2. **Configure Context**:
- Enter your context information in the text field
- Adjust the scaling slider to control context influence
3. **Set Bias Scaler**: Adjust the bias scaler to control overall model behavior
4. **Start Chatting**: Once the model is loaded, type your message and press Shift+Enter or click Send
5. **Reset**: Use the "Reset Chat" button to start a new conversation
💡 **Tip**: You can use context to provide background information or specific knowledge
that should influence the model's responses.
"""
)
with gr.Row():
with gr.Column(scale=1, elem_classes="panel-box"):
gr.Markdown("### 📦 Load Checkpoint")
gr.Markdown(
"*Select a trained hypernetwork checkpoint to begin.*",
elem_classes="instruction-text",
)
checkpoint_dropdown = gr.Dropdown(
choices=get_available_checkpoints(),
label="Select Checkpoint",
value=None,
interactive=True,
elem_id="checkpoint-dropdown",
)
load_btn = gr.Button("Load Checkpoint", variant="primary", size="lg")
status_box = gr.Textbox(
label="Status",
lines=8,
interactive=False,
elem_classes="status-box",
)
gr.Markdown("---")
gr.HTML(
"""
<div class="context-section-header">
<strong>🧠 Context Internalization (Hypernetwork Input)</strong><br>
<small>This context modulates the model internally — it is NOT shown to the base model</small>
</div>
"""
)
context = gr.Textbox(
label="🧠 Context (Internalized via Hypernetwork)",
placeholder="Enter context to be internalized by the hypernetwork...",
lines=4,
value=DEFAULT_CONTEXT,
)
context_scaler = gr.Slider(
minimum=-2.0,
maximum=2.0,
step=0.01,
value=1.0,
label="Context Scaling",
)
gr.Markdown("---")
bias_scaler = gr.Slider(
minimum=-2.0,
maximum=2.0,
step=0.01,
value=1.0,
label="Bias Scaler",
info="A single scalar applied to bias parameters (independent of contexts)",
)
with gr.Column(scale=2, elem_classes="chat-panel-box"):
gr.HTML(
"""
<div class="chat-section-header">
<strong>💬 Chat Interface (Direct Input to Base Model)</strong><br>
<small>Your messages here are the ONLY text the base model sees — contexts above influence it internally</small>
</div>
"""
)
chat_status_notice = gr.HTML(
"""
<div class="chat-disabled-notice">
🔒 <strong>Chat Disabled:</strong> Please load a checkpoint first to enable chat functionality.
</div>
""",
visible=True,
)
system_msg = gr.Textbox(
label="System Message (Optional - Sent to Base Model)",
placeholder="Load a checkpoint to enable chat...",
lines=2,
interactive=False,
)
chatbot = gr.Chatbot(
label="Conversation",
show_copy_button=True,
height=500,
elem_id="chatbot",
elem_classes="chat-container",
value=[
[
None,
"🔒 Chat is currently disabled. Please load a checkpoint from the left panel to begin chatting.",
]
],
)
with gr.Row():
msg = gr.Textbox(
label="Your Message (Sent Directly to Base Model)",
placeholder="⚠️ Load a checkpoint first to start chatting...",
lines=2,
scale=4,
interactive=False,
)
send_btn = gr.Button(
"🔒 Send (Disabled)",
variant="primary",
scale=1,
interactive=False,
)
with gr.Row():
clear_btn = gr.Button(
"🔒 Reset Chat (Disabled)",
variant="secondary",
interactive=False,
)
reset_status = gr.Textbox(label="Reset Status", visible=False)
load_btn.click(
fn=load_checkpoint,
inputs=[checkpoint_dropdown],
outputs=[status_box, msg, send_btn, system_msg, clear_btn],
).then(
fn=lambda: (
gr.update(visible=False),
gr.update(
placeholder="Type your message here... (Shift+Enter for new line)"
),
gr.update(value="Send"),
gr.update(value="🔄 Reset Chat"),
gr.update(value=[[None, WARNING_MESSAGE]]),
),
outputs=[chat_status_notice, msg, send_btn, clear_btn, chatbot],
)
msg.submit(
fn=add_user_message,
inputs=[msg, chatbot],
outputs=[chatbot, msg],
).then(
fn=generate_response,
inputs=[
chatbot,
system_msg,
context,
context_scaler,
bias_scaler,
],
outputs=[chatbot],
)
send_btn.click(
fn=add_user_message,
inputs=[msg, chatbot],
outputs=[chatbot, msg],
).then(
fn=generate_response,
inputs=[
chatbot,
system_msg,
context,
context_scaler,
bias_scaler,
],
outputs=[chatbot],
)
clear_btn.click(
fn=reset_chat,
inputs=[system_msg],
outputs=[chatbot, reset_status],
)
gr.Markdown(f"---\n{FOOTER.strip()}")
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(
server_name="0.0.0.0",
server_port=7861,
share=False,
debug=True,
)