From ec78ef00b53a7a8bfaef6241f1f579a0275ad6ec Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 15 Aug 2024 21:19:19 +0800 Subject: [PATCH] update exp example --- examples/exp_pool/init_exp_pool.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/exp_pool/init_exp_pool.py b/examples/exp_pool/init_exp_pool.py index 321c38d78..62747b8d8 100644 --- a/examples/exp_pool/init_exp_pool.py +++ b/examples/exp_pool/init_exp_pool.py @@ -46,6 +46,7 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None): metric=metric or Metric(score=Score(val=10, reason="Manual")), ) exp_manager = get_exp_manager() + exp_manager.config.exp_pool.enabled = True exp_manager.config.exp_pool.enable_write = True exp_manager.create_exp(exp) logger.info(f"New experience created for the request `{req[:10]}`.") @@ -59,8 +60,10 @@ async def add_exps(exps: list, tag: str): tag: A tag for categorizing the experiences. """ - - tasks = [add_exp(req=json.dumps(exp["req"]), resp=exp["resp"], tag=tag) for exp in exps] + tasks = [ + add_exp(req=exp["req"] if isinstance(exp["req"], str) else json.dumps(exp["req"]), resp=exp["resp"], tag=tag) + for exp in exps + ] await asyncio.gather(*tasks)