get_human_input interface

This commit is contained in:
yzlin 2024-04-13 11:44:31 +08:00
parent c6e42631da
commit ed8777db99
3 changed files with 20 additions and 3 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Tuple
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.logs import get_human_input, logger
from metagpt.schema import Message, Plan
@ -50,7 +50,7 @@ class AskReview(Action):
"Please type your review below:\n"
)
rsp = input(prompt)
rsp = await get_human_input(prompt)
if rsp.lower() in ReviewConst.EXIT_WORDS:
exit()

View file

@ -8,6 +8,7 @@
from __future__ import annotations
import inspect
import sys
from datetime import datetime
from functools import partial
@ -59,6 +60,14 @@ async def log_tool_output_async(output: ToolLogItem | list[ToolLogItem], tool_na
await _tool_output_log_async(output=output, tool_name=tool_name)
async def get_human_input(prompt: str = ""):
"""interface for getting human input, can be set to get input from different sources with set_human_input_func"""
if inspect.iscoroutinefunction(_get_human_input):
return await _get_human_input(prompt)
else:
return _get_human_input(prompt)
def set_llm_stream_logfunc(func):
global _llm_stream_log
_llm_stream_log = func
@ -75,6 +84,11 @@ async def set_tool_output_logfunc_async(func):
_tool_output_log_async = func
def set_human_input_func(func):
global _get_human_input
_get_human_input = func
_llm_stream_log = partial(print, end="")
@ -86,3 +100,6 @@ _tool_output_log = (
async def _tool_output_log_async(*args, **kwargs):
# async version
pass
_get_human_input = input # get human input from console by default

View file

@ -6,7 +6,7 @@ from metagpt.actions.di.ask_review import AskReview
@pytest.mark.asyncio
async def test_ask_review(mocker):
mock_review_input = "confirm"
mocker.patch("builtins.input", return_value=mock_review_input)
mocker.patch("metagpt.actions.di.ask_review.get_human_input", return_value=mock_review_input)
rsp, confirmed = await AskReview().run()
assert rsp == mock_review_input
assert confirmed