mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00
add spo streamlit app
This commit is contained in:
parent
46fcaa097e
commit
7d5ed2a7f8
1 changed files with 199 additions and 0 deletions
199
metagpt/ext/spo/app.py
Normal file
199
metagpt/ext/spo/app.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
# Make the repository root importable when this file is run directly
# (e.g. `streamlit run metagpt/ext/spo/app.py`).
import sys

from pathlib import Path

# app.py lives at <root>/metagpt/ext/spo/, so four `.parent` hops reach the repo root.
root_path = Path(__file__).parent.parent.parent.parent
sys.path.append(str(root_path))

import streamlit as st
import yaml
import os  # NOTE(review): appears unused in this file — confirm before removing
from metagpt.ext.spo.components.optimizer import PromptOptimizer
from metagpt.ext.spo.utils.llm_client import SPO_LLM
|
||||
|
||||
|
||||
def load_yaml_template(template_path):
    """Load a prompt template from a YAML file.

    Args:
        template_path: ``pathlib.Path`` to the template YAML file.

    Returns:
        dict: The parsed template. When the file does not exist — or parses
        to an empty document — a default skeleton with empty ``prompt``,
        ``requirements``, ``count`` and one blank FAQ entry is returned,
        so callers can always call ``.get(...)`` on the result.
    """
    default_template = {"prompt": "", "requirements": "", "count": None, "faq": [{"question": "", "answer": ""}]}
    if template_path.exists():
        with open(template_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # yaml.safe_load returns None for an empty document; fall back to the
        # default skeleton instead of handing callers a None (which would
        # break the `.get(...)` calls in main()).
        if data is not None:
            return data
    return default_template
|
||||
|
||||
|
||||
def save_yaml_template(template_path, data):
    """Normalize *data* into the template schema and write it to *template_path*.

    Args:
        template_path: ``pathlib.Path`` destination for the YAML file; parent
            directories are created as needed.
        data: Mapping with optional ``prompt``, ``requirements``, ``count``
            and ``faq`` entries.
    """
    # Coerce every field to its expected type so the written file always
    # matches the template schema.
    normalized_faqs = []
    for faq in data.get("faq", []):
        normalized_faqs.append({
            "question": str(faq.get("question", "")),
            "answer": str(faq.get("answer", "")),
        })

    template_format = {
        "prompt": str(data.get("prompt", "")),
        "requirements": str(data.get("requirements", "")),
        "count": data.get("count"),
        "faq": normalized_faqs,
    }

    # Create the parent directory if it does not exist yet.
    template_path.parent.mkdir(parents=True, exist_ok=True)

    # Persist, keeping non-ASCII text readable and the key order stable.
    with open(template_path, 'w', encoding='utf-8') as f:
        yaml.dump(template_format, f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
def main():
    """Streamlit entry point: edit SPO prompt templates and launch optimization.

    Renders a sidebar for template selection and LLM/optimizer settings, a
    template editor (prompt, requirements, FAQ examples) with a live YAML
    preview, and a button that runs ``PromptOptimizer`` over the selected
    template.
    """
    st.title("SPO Prompt Optimizer")

    # Sidebar for configurations
    with st.sidebar:
        st.header("Configuration")

        # Template Selection/Creation
        settings_path = Path("metagpt/ext/spo/settings")
        existing_templates = [f.stem for f in settings_path.glob("*.yaml")]

        template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])

        if template_mode == "Use Existing":
            template_name = st.selectbox("Select Template", existing_templates)
        else:
            template_name = st.text_input("New Template Name")
            # The ".yaml" suffix is appended when building the path below, so
            # strip it here if the user typed it; otherwise the file would be
            # saved as "<name>.yaml.yaml". (The original check assigned
            # f"{template_name}", a no-op.)
            if template_name and template_name.endswith('.yaml'):
                template_name = template_name[: -len('.yaml')]

        # LLM Settings
        st.subheader("LLM Settings")
        opt_model = st.selectbox(
            "Optimization Model",
            ["claude-3-5-sonnet-20240620", "gpt-4o", "gpt-4o-mini", "deepseek-chat"],
            index=0
        )
        opt_temp = st.slider("Optimization Temperature", 0.0, 1.0, 0.7)

        eval_model = st.selectbox(
            "Evaluation Model",
            ["claude-3-5-sonnet-20240620", "gpt-4o", "gpt-4o-mini", "deepseek-chat"],
            index=0
        )
        eval_temp = st.slider("Evaluation Temperature", 0.0, 1.0, 0.3)

        exec_model = st.selectbox(
            "Execution Model",
            ["claude-3-5-sonnet-20240620", "gpt-4o", "gpt-4o-mini", "deepseek-chat"],
            index=0
        )
        exec_temp = st.slider("Execution Temperature", 0.0, 1.0, 0.0)

        # Optimizer Settings
        st.subheader("Optimizer Settings")
        initial_round = st.number_input("Initial Round", 1, 100, 1)
        max_rounds = st.number_input("Maximum Rounds", 1, 100, 10)

    # Main content area
    st.header("Template Configuration")

    if template_name:
        template_path = settings_path / f"{template_name}.yaml"
        template_data = load_yaml_template(template_path)

        # Track the selected template in session state so the FAQ list is
        # reset whenever the user switches templates.
        if 'current_template' not in st.session_state or st.session_state.current_template != template_name:
            st.session_state.current_template = template_name
            st.session_state.faqs = template_data.get('faq', [])

        # Edit template sections
        prompt = st.text_area("Prompt", template_data.get('prompt', ''), height=100)
        requirements = st.text_area("Requirements", template_data.get('requirements', ''), height=100)

        # FAQ section
        st.subheader("FAQ Examples")

        # Add new FAQ button
        if st.button("Add New FAQ"):
            st.session_state.faqs.append({"question": "", "answer": ""})

        # Edit FAQs
        new_faqs = []
        for i in range(len(st.session_state.faqs)):
            st.markdown(f"**FAQ #{i + 1}**")
            col1, col2, col3 = st.columns([45, 45, 10])

            # Unique widget keys keep each FAQ's state independent.
            with col1:
                question = st.text_area(
                    f"Question {i + 1}",
                    st.session_state.faqs[i].get('question', ''),
                    key=f"q_{i}",
                    height=100
                )
            with col2:
                answer = st.text_area(
                    f"Answer {i + 1}",
                    st.session_state.faqs[i].get('answer', ''),
                    key=f"a_{i}",
                    height=100
                )
            with col3:
                # Deleting pops the entry and reruns the script immediately,
                # so the stale loop index is never reused.
                if st.button("🗑️", key=f"delete_{i}"):
                    st.session_state.faqs.pop(i)
                    st.rerun()

            new_faqs.append({"question": question, "answer": answer})

        # Save template button
        if st.button("Save Template"):
            new_template_data = {
                "prompt": prompt,
                "requirements": requirements,
                "count": None,
                "faq": new_faqs
            }
            # Persist to disk
            save_yaml_template(template_path, new_template_data)
            # Keep session state in sync with what was saved
            st.session_state.faqs = new_faqs
            st.success(f"Template saved to {template_path}")

        # Live YAML preview of the current (possibly unsaved) template
        st.subheader("Current Template Preview")
        preview_data = {
            "prompt": prompt,
            "requirements": requirements,
            "count": None,
            "faq": new_faqs
        }
        st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")

        # Start optimization button
        if st.button("Start Optimization"):
            try:
                # Initialize LLM
                SPO_LLM.initialize(
                    optimize_kwargs={"model": opt_model, "temperature": opt_temp},
                    evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
                    execute_kwargs={"model": exec_model, "temperature": exec_temp},
                )

                # Create optimizer instance
                optimizer = PromptOptimizer(
                    optimized_path="workspace",
                    initial_round=initial_round,
                    max_rounds=max_rounds,
                    template=f"{template_name}.yaml",
                    name=template_name,
                    iteration=True,
                )

                # Run optimization with progress bar
                with st.spinner("Optimizing prompts..."):
                    optimizer.optimize()

                st.success("Optimization completed!")

            except Exception as e:
                # Surface any failure in the UI rather than crashing the app.
                st.error(f"An error occurred: {str(e)}")
|
||||
|
||||
|
||||
# Script entry point (run via `streamlit run metagpt/ext/spo/app.py`).
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue