update spo app.py to test prompt

This commit is contained in:
xiangjinyu 2025-02-12 16:38:16 +08:00
parent 4417d805de
commit ac3623cd84

View file

@ -1,3 +1,4 @@
import asyncio
import sys
from pathlib import Path
@ -9,7 +10,7 @@ sys.path.append(str(Path(__file__).parents[3]))
from metagpt.const import METAGPT_ROOT # noqa: E402
from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402
from metagpt.ext.spo.utils.llm_client import SPO_LLM # noqa: E402
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402
def load_yaml_template(template_path):
@ -36,7 +37,47 @@ def save_yaml_template(template_path, data):
yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
def display_optimization_results(result_data):
for result in result_data:
round_num = result["round"]
success = result["succeed"]
prompt = result["prompt"]
with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
st.markdown("**Prompt:**")
st.code(prompt, language="text")
st.markdown("<br>", unsafe_allow_html=True)
col1, col2 = st.columns(2)
with col1:
st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
with col2:
st.markdown(f"**Tokens:** {result['tokens']}")
st.markdown("**Answers:**")
for idx, answer in enumerate(result["answers"]):
st.markdown(f"**Question {idx + 1}:**")
st.text(answer["question"])
st.markdown("**Answer:**")
st.text(answer["answer"])
st.markdown("---")
# Summary
success_count = sum(1 for r in result_data if r["succeed"])
total_rounds = len(result_data)
st.markdown("### Summary")
col1, col2 = st.columns(2)
with col1:
st.metric("Total Rounds", total_rounds)
with col2:
st.metric("Successful Rounds", success_count)
def main():
if "optimization_results" not in st.session_state:
st.session_state.optimization_results = []
st.title("SPO | Self-Supervised Prompt Optimization 🤖")
# Sidebar for configurations
@ -189,45 +230,57 @@ def main():
prompt_path = f"{optimizer.root_path}/prompts"
result_data = optimizer.data_utils.load_results(prompt_path)
for result in result_data:
round_num = result["round"]
success = result["succeed"]
prompt = result["prompt"]
with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
st.markdown("**Prompt:**")
st.code(prompt, language="text")
st.markdown("<br>", unsafe_allow_html=True)
col1, col2 = st.columns(2)
with col1:
st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
with col2:
st.markdown(f"**Tokens:** {result['tokens']}")
st.markdown("**Answers:**")
for idx, answer in enumerate(result["answers"]):
st.markdown(f"**Question {idx + 1}:**")
st.text(answer["question"])
st.markdown("**Answer:**")
st.text(answer["answer"])
st.markdown("---")
# Summary
success_count = sum(1 for r in result_data if r["succeed"])
total_rounds = len(result_data)
st.markdown("### Summary")
col1, col2 = st.columns(2)
with col1:
st.metric("Total Rounds", total_rounds)
with col2:
st.metric("Successful Rounds", success_count)
st.session_state.optimization_results = result_data
except Exception as e:
st.error(f"An error occurred: {str(e)}")
_logger.error(f"Error during optimization: {str(e)}")
if st.session_state.optimization_results:
st.header("Optimization Results")
display_optimization_results(st.session_state.optimization_results)
st.markdown("---")
st.subheader("Test Optimized Prompt")
col1, col2 = st.columns(2)
with col1:
test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")
with col2:
test_question = st.text_area("Your Question", value="", height=200, key="test_question")
if st.button("Test Prompt"):
if test_prompt and test_question:
try:
with st.spinner("Generating response..."):
SPO_LLM.initialize(
optimize_kwargs={"model": opt_model, "temperature": opt_temp},
evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
execute_kwargs={"model": exec_model, "temperature": exec_temp},
)
llm = SPO_LLM.get_instance()
messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]
async def get_response():
return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
response = loop.run_until_complete(get_response())
finally:
loop.close()
st.subheader("Response:")
st.markdown(response)
except Exception as e:
st.error(f"Error generating response: {str(e)}")
else:
st.warning("Please enter both prompt and question.")
if __name__ == "__main__":
main()