feat: ask_code message support type list[Message].

This commit is contained in:
刘棒棒 2023-11-21 11:40:15 +08:00
parent 66f647276e
commit 4418fe8283
2 changed files with 11 additions and 2 deletions

View file

@ -262,10 +262,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
self._update_costs(rsp.get("usage"))
return rsp
def _process_message(self, messages: Union[str, Message, list[dict]]) -> list[dict]:
def _process_message(self, messages: Union[str, Message, list[dict], list[Message]]) -> list[dict]:
"""convert messages to list[dict]."""
if isinstance(messages, list):
return messages
return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages]
if isinstance(messages, Message):
messages = [messages.to_dict()]

View file

@ -59,3 +59,12 @@ def test_ask_code_Message():
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_Message():
llm = OpenAIGPTAPI()
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0