diff --git a/examples/exp_pool/init_exp_pool.py b/examples/exp_pool/init_exp_pool.py index 321c38d78..62747b8d8 100644 --- a/examples/exp_pool/init_exp_pool.py +++ b/examples/exp_pool/init_exp_pool.py @@ -46,6 +46,7 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None): metric=metric or Metric(score=Score(val=10, reason="Manual")), ) exp_manager = get_exp_manager() + exp_manager.config.exp_pool.enabled = True exp_manager.config.exp_pool.enable_write = True exp_manager.create_exp(exp) logger.info(f"New experience created for the request `{req[:10]}`.") @@ -59,8 +60,10 @@ async def add_exps(exps: list, tag: str): tag: A tag for categorizing the experiences. """ - - tasks = [add_exp(req=json.dumps(exp["req"]), resp=exp["resp"], tag=tag) for exp in exps] + tasks = [ + add_exp(req=exp["req"] if isinstance(exp["req"], str) else json.dumps(exp["req"]), resp=exp["resp"], tag=tag) + for exp in exps + ] await asyncio.gather(*tasks)