diff --git a/metagpt/actions/di/ask_review.py b/metagpt/actions/di/ask_review.py index 041011e80..ecbbd992e 100644 --- a/metagpt/actions/di/ask_review.py +++ b/metagpt/actions/di/ask_review.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Tuple from metagpt.actions import Action -from metagpt.logs import logger +from metagpt.logs import get_human_input, logger from metagpt.schema import Message, Plan @@ -50,7 +50,7 @@ class AskReview(Action): "Please type your review below:\n" ) - rsp = input(prompt) + rsp = await get_human_input(prompt) if rsp.lower() in ReviewConst.EXIT_WORDS: exit() diff --git a/metagpt/logs.py b/metagpt/logs.py index b208e0868..d6b7cc419 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -8,6 +8,7 @@ from __future__ import annotations +import inspect import sys from datetime import datetime from functools import partial @@ -59,6 +60,14 @@ async def log_tool_output_async(output: ToolLogItem | list[ToolLogItem], tool_na await _tool_output_log_async(output=output, tool_name=tool_name) +async def get_human_input(prompt: str = ""): + """interface for getting human input, can be set to get input from different sources with set_human_input_func""" + if inspect.iscoroutinefunction(_get_human_input): + return await _get_human_input(prompt) + else: + return _get_human_input(prompt) + + def set_llm_stream_logfunc(func): global _llm_stream_log _llm_stream_log = func @@ -75,6 +84,11 @@ async def set_tool_output_logfunc_async(func): _tool_output_log_async = func +def set_human_input_func(func): + global _get_human_input + _get_human_input = func + + _llm_stream_log = partial(print, end="") @@ -86,3 +100,6 @@ _tool_output_log = ( async def _tool_output_log_async(*args, **kwargs): # async version pass + + +_get_human_input = input # get human input from console by default diff --git a/tests/metagpt/actions/di/test_ask_review.py b/tests/metagpt/actions/di/test_ask_review.py index 6bb1accf5..d49ad176a 100644 --- a/tests/metagpt/actions/di/test_ask_review.py +++ b/tests/metagpt/actions/di/test_ask_review.py @@ -6,7 +6,7 @@ from metagpt.actions.di.ask_review import AskReview @pytest.mark.asyncio async def test_ask_review(mocker): mock_review_input = "confirm" - mocker.patch("builtins.input", return_value=mock_review_input) + mocker.patch("metagpt.actions.di.ask_review.get_human_input", return_value=mock_review_input) rsp, confirmed = await AskReview().run() assert rsp == mock_review_input assert confirmed