From ac3623cd84c391b64c4ada6f73d863e858af9f2b Mon Sep 17 00:00:00 2001 From: xiangjinyu Date: Wed, 12 Feb 2025 16:38:16 +0800 Subject: [PATCH] update spo app.py to test prompt --- metagpt/ext/spo/app.py | 123 +++++++++++++++++++++++++++++------------ 1 file changed, 88 insertions(+), 35 deletions(-) diff --git a/metagpt/ext/spo/app.py b/metagpt/ext/spo/app.py index 563eb92ff..183757124 100644 --- a/metagpt/ext/spo/app.py +++ b/metagpt/ext/spo/app.py @@ -1,3 +1,4 @@ +import asyncio import sys from pathlib import Path @@ -9,7 +10,7 @@ sys.path.append(str(Path(__file__).parents[3])) from metagpt.const import METAGPT_ROOT # noqa: E402 from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402 -from metagpt.ext.spo.utils.llm_client import SPO_LLM # noqa: E402 +from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402 def load_yaml_template(template_path): @@ -36,7 +37,47 @@ def save_yaml_template(template_path, data): yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2) +def display_optimization_results(result_data): + for result in result_data: + round_num = result["round"] + success = result["succeed"] + prompt = result["prompt"] + + with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"): + st.markdown("**Prompt:**") + st.code(prompt, language="text") + st.markdown("
", unsafe_allow_html=True) + + col1, col2 = st.columns(2) + with col1: + st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}") + with col2: + st.markdown(f"**Tokens:** {result['tokens']}") + + st.markdown("**Answers:**") + for idx, answer in enumerate(result["answers"]): + st.markdown(f"**Question {idx + 1}:**") + st.text(answer["question"]) + st.markdown("**Answer:**") + st.text(answer["answer"]) + st.markdown("---") + + # Summary + success_count = sum(1 for r in result_data if r["succeed"]) + total_rounds = len(result_data) + + st.markdown("### Summary") + col1, col2 = st.columns(2) + with col1: + st.metric("Total Rounds", total_rounds) + with col2: + st.metric("Successful Rounds", success_count) + + def main(): + if "optimization_results" not in st.session_state: + st.session_state.optimization_results = [] + st.title("SPO | Self-Supervised Prompt Optimization 🤖") # Sidebar for configurations @@ -189,45 +230,57 @@ def main(): prompt_path = f"{optimizer.root_path}/prompts" result_data = optimizer.data_utils.load_results(prompt_path) - for result in result_data: - round_num = result["round"] - success = result["succeed"] - prompt = result["prompt"] - - with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"): - st.markdown("**Prompt:**") - st.code(prompt, language="text") - st.markdown("
", unsafe_allow_html=True) - - col1, col2 = st.columns(2) - with col1: - st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}") - with col2: - st.markdown(f"**Tokens:** {result['tokens']}") - - st.markdown("**Answers:**") - for idx, answer in enumerate(result["answers"]): - st.markdown(f"**Question {idx + 1}:**") - st.text(answer["question"]) - st.markdown("**Answer:**") - st.text(answer["answer"]) - st.markdown("---") - - # Summary - success_count = sum(1 for r in result_data if r["succeed"]) - total_rounds = len(result_data) - - st.markdown("### Summary") - col1, col2 = st.columns(2) - with col1: - st.metric("Total Rounds", total_rounds) - with col2: - st.metric("Successful Rounds", success_count) + st.session_state.optimization_results = result_data except Exception as e: st.error(f"An error occurred: {str(e)}") _logger.error(f"Error during optimization: {str(e)}") + if st.session_state.optimization_results: + st.header("Optimization Results") + display_optimization_results(st.session_state.optimization_results) + + st.markdown("---") + st.subheader("Test Optimized Prompt") + col1, col2 = st.columns(2) + + with col1: + test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt") + + with col2: + test_question = st.text_area("Your Question", value="", height=200, key="test_question") + + if st.button("Test Prompt"): + if test_prompt and test_question: + try: + with st.spinner("Generating response..."): + SPO_LLM.initialize( + optimize_kwargs={"model": opt_model, "temperature": opt_temp}, + evaluate_kwargs={"model": eval_model, "temperature": eval_temp}, + execute_kwargs={"model": exec_model, "temperature": exec_temp}, + ) + + llm = SPO_LLM.get_instance() + messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}] + + async def get_response(): + return await llm.responser(request_type=RequestType.EXECUTE, messages=messages) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(get_response()) + finally: + loop.close() + + st.subheader("Response:") + st.markdown(response) + + except Exception as e: + st.error(f"Error generating response: {str(e)}") + else: + st.warning("Please enter both prompt and question.") + if __name__ == "__main__": main()