diff --git a/metagpt/ext/spo/app.py b/metagpt/ext/spo/app.py
index 563eb92ff..183757124 100644
--- a/metagpt/ext/spo/app.py
+++ b/metagpt/ext/spo/app.py
@@ -1,3 +1,4 @@
+import asyncio
import sys
from pathlib import Path
@@ -9,7 +10,7 @@ sys.path.append(str(Path(__file__).parents[3]))
from metagpt.const import METAGPT_ROOT # noqa: E402
from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402
-from metagpt.ext.spo.utils.llm_client import SPO_LLM # noqa: E402
+from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402
def load_yaml_template(template_path):
@@ -36,7 +37,47 @@ def save_yaml_template(template_path, data):
yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
+def display_optimization_results(result_data):
+    """Render every optimization round, followed by an overall summary.
+
+    Each round gets its own expander showing the prompt, the success status,
+    the token count and the question/answer pairs. Two metrics below the
+    expanders report the total and the successful number of rounds.
+
+    Args:
+        result_data: Sequence of per-round dicts. Each dict is read for the
+            keys "round", "succeed", "prompt", "tokens" and "answers"; every
+            entry of "answers" is read for "question" and "answer". It is
+            iterated twice and passed to ``len``, so it must not be a
+            one-shot iterator.
+    """
+    for result in result_data:
+        round_num = result["round"]
+        success = result["succeed"]
+        prompt = result["prompt"]
+
+        with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
+            st.markdown("**Prompt:**")
+            st.code(prompt, language="text")
+            # Raw HTML line break used as vertical spacing; it only renders
+            # because unsafe_allow_html is enabled.
+            st.markdown("<br>", unsafe_allow_html=True)
+
+            col1, col2 = st.columns(2)
+            with col1:
+                st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
+            with col2:
+                st.markdown(f"**Tokens:** {result['tokens']}")
+
+            st.markdown("**Answers:**")
+            for idx, answer in enumerate(result["answers"]):
+                st.markdown(f"**Question {idx + 1}:**")
+                st.text(answer["question"])
+                st.markdown("**Answer:**")
+                st.text(answer["answer"])
+                st.markdown("---")
+
+    # Summary
+    success_count = sum(1 for r in result_data if r["succeed"])
+    total_rounds = len(result_data)
+
+    st.markdown("### Summary")
+    col1, col2 = st.columns(2)
+    with col1:
+        st.metric("Total Rounds", total_rounds)
+    with col2:
+        st.metric("Successful Rounds", success_count)
+
+
def main():
+ if "optimization_results" not in st.session_state:
+ st.session_state.optimization_results = []
+
st.title("SPO | Self-Supervised Prompt Optimization 🤖")
# Sidebar for configurations
@@ -189,45 +230,57 @@ def main():
prompt_path = f"{optimizer.root_path}/prompts"
result_data = optimizer.data_utils.load_results(prompt_path)
- for result in result_data:
- round_num = result["round"]
- success = result["succeed"]
- prompt = result["prompt"]
-
- with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
- st.markdown("**Prompt:**")
- st.code(prompt, language="text")
- st.markdown("<br>", unsafe_allow_html=True)
-
- col1, col2 = st.columns(2)
- with col1:
- st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
- with col2:
- st.markdown(f"**Tokens:** {result['tokens']}")
-
- st.markdown("**Answers:**")
- for idx, answer in enumerate(result["answers"]):
- st.markdown(f"**Question {idx + 1}:**")
- st.text(answer["question"])
- st.markdown("**Answer:**")
- st.text(answer["answer"])
- st.markdown("---")
-
- # Summary
- success_count = sum(1 for r in result_data if r["succeed"])
- total_rounds = len(result_data)
-
- st.markdown("### Summary")
- col1, col2 = st.columns(2)
- with col1:
- st.metric("Total Rounds", total_rounds)
- with col2:
- st.metric("Successful Rounds", success_count)
+ st.session_state.optimization_results = result_data
except Exception as e:
st.error(f"An error occurred: {str(e)}")
_logger.error(f"Error during optimization: {str(e)}")
+    # Show whatever results the last optimization run stored in session state.
+    if st.session_state.optimization_results:
+        st.header("Optimization Results")
+        display_optimization_results(st.session_state.optimization_results)
+
+    st.markdown("---")
+    st.subheader("Test Optimized Prompt")
+    col1, col2 = st.columns(2)
+
+    with col1:
+        test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")
+
+    with col2:
+        test_question = st.text_area("Your Question", value="", height=200, key="test_question")
+
+    if st.button("Test Prompt"):
+        if test_prompt and test_question:
+            try:
+                with st.spinner("Generating response..."):
+                    SPO_LLM.initialize(
+                        optimize_kwargs={"model": opt_model, "temperature": opt_temp},
+                        evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
+                        execute_kwargs={"model": exec_model, "temperature": exec_temp},
+                    )
+
+                    llm = SPO_LLM.get_instance()
+                    # Prompt and question are sent together as a single user message.
+                    messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]
+
+                    async def get_response():
+                        return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)
+
+                    # asyncio.run creates a fresh event loop, runs the coroutine
+                    # and tears the loop down again, so no closed loop is left
+                    # registered as this thread's current event loop (which the
+                    # manual new_event_loop/set_event_loop/close sequence did).
+                    response = asyncio.run(get_response())
+
+                    st.subheader("Response:")
+                    st.markdown(response)
+
+            except Exception as e:
+                # UI boundary: surface any failure to the user instead of crashing the page.
+                st.error(f"Error generating response: {str(e)}")
+        else:
+            st.warning("Please enter both prompt and question.")
+
if __name__ == "__main__":
main()