mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
update spo app.py to test prompt
This commit is contained in:
parent
4417d805de
commit
ac3623cd84
1 changed files with 88 additions and 35 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -9,7 +10,7 @@ sys.path.append(str(Path(__file__).parents[3]))
|
|||
|
||||
from metagpt.const import METAGPT_ROOT # noqa: E402
|
||||
from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402
|
||||
from metagpt.ext.spo.utils.llm_client import SPO_LLM # noqa: E402
|
||||
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402
|
||||
|
||||
|
||||
def load_yaml_template(template_path):
|
||||
|
|
@ -36,7 +37,47 @@ def save_yaml_template(template_path, data):
|
|||
yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
|
||||
|
||||
|
||||
def display_optimization_results(result_data):
|
||||
for result in result_data:
|
||||
round_num = result["round"]
|
||||
success = result["succeed"]
|
||||
prompt = result["prompt"]
|
||||
|
||||
with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
|
||||
st.markdown("**Prompt:**")
|
||||
st.code(prompt, language="text")
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
|
||||
with col2:
|
||||
st.markdown(f"**Tokens:** {result['tokens']}")
|
||||
|
||||
st.markdown("**Answers:**")
|
||||
for idx, answer in enumerate(result["answers"]):
|
||||
st.markdown(f"**Question {idx + 1}:**")
|
||||
st.text(answer["question"])
|
||||
st.markdown("**Answer:**")
|
||||
st.text(answer["answer"])
|
||||
st.markdown("---")
|
||||
|
||||
# Summary
|
||||
success_count = sum(1 for r in result_data if r["succeed"])
|
||||
total_rounds = len(result_data)
|
||||
|
||||
st.markdown("### Summary")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric("Total Rounds", total_rounds)
|
||||
with col2:
|
||||
st.metric("Successful Rounds", success_count)
|
||||
|
||||
|
||||
def main():
|
||||
if "optimization_results" not in st.session_state:
|
||||
st.session_state.optimization_results = []
|
||||
|
||||
st.title("SPO | Self-Supervised Prompt Optimization 🤖")
|
||||
|
||||
# Sidebar for configurations
|
||||
|
|
@ -189,45 +230,57 @@ def main():
|
|||
prompt_path = f"{optimizer.root_path}/prompts"
|
||||
result_data = optimizer.data_utils.load_results(prompt_path)
|
||||
|
||||
for result in result_data:
|
||||
round_num = result["round"]
|
||||
success = result["succeed"]
|
||||
prompt = result["prompt"]
|
||||
|
||||
with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
|
||||
st.markdown("**Prompt:**")
|
||||
st.code(prompt, language="text")
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
|
||||
with col2:
|
||||
st.markdown(f"**Tokens:** {result['tokens']}")
|
||||
|
||||
st.markdown("**Answers:**")
|
||||
for idx, answer in enumerate(result["answers"]):
|
||||
st.markdown(f"**Question {idx + 1}:**")
|
||||
st.text(answer["question"])
|
||||
st.markdown("**Answer:**")
|
||||
st.text(answer["answer"])
|
||||
st.markdown("---")
|
||||
|
||||
# Summary
|
||||
success_count = sum(1 for r in result_data if r["succeed"])
|
||||
total_rounds = len(result_data)
|
||||
|
||||
st.markdown("### Summary")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric("Total Rounds", total_rounds)
|
||||
with col2:
|
||||
st.metric("Successful Rounds", success_count)
|
||||
st.session_state.optimization_results = result_data
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"An error occurred: {str(e)}")
|
||||
_logger.error(f"Error during optimization: {str(e)}")
|
||||
|
||||
if st.session_state.optimization_results:
|
||||
st.header("Optimization Results")
|
||||
display_optimization_results(st.session_state.optimization_results)
|
||||
|
||||
st.markdown("---")
|
||||
st.subheader("Test Optimized Prompt")
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")
|
||||
|
||||
with col2:
|
||||
test_question = st.text_area("Your Question", value="", height=200, key="test_question")
|
||||
|
||||
if st.button("Test Prompt"):
|
||||
if test_prompt and test_question:
|
||||
try:
|
||||
with st.spinner("Generating response..."):
|
||||
SPO_LLM.initialize(
|
||||
optimize_kwargs={"model": opt_model, "temperature": opt_temp},
|
||||
evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
|
||||
execute_kwargs={"model": exec_model, "temperature": exec_temp},
|
||||
)
|
||||
|
||||
llm = SPO_LLM.get_instance()
|
||||
messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]
|
||||
|
||||
async def get_response():
|
||||
return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
response = loop.run_until_complete(get_response())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
st.subheader("Response:")
|
||||
st.markdown(response)
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error generating response: {str(e)}")
|
||||
else:
|
||||
st.warning("Please enter both prompt and question.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue