Merge branch 'del-_parse_arguments-in-openai_api.py' into code_intepreter

This commit is contained in:
yzlin 2024-02-04 23:41:39 +08:00
commit ca6749b5f1
3 changed files with 33 additions and 63 deletions

View file

@ -1,3 +1,5 @@
import json
import pytest
from openai.types.chat import (
ChatCompletion,
@ -41,16 +43,6 @@ async def test_speech_to_text():
def tool_calls_rsp():
function_rsps = [
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'}', name="execute"),
Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute"),
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```}', name="execute"),
Function(arguments='{\n"language": "python",\n"code": """print("hello world")"""}', name="execute"),
Function(arguments='\nprint("hello world")\\n', name="execute"),
# only `{` in arguments
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
# no `{`, `}` in arguments
Function(arguments='\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
]
tool_calls = [
ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) for i, f in enumerate(function_rsps)
@ -64,10 +56,6 @@ def tool_calls_rsp():
messages.extend(
[
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
ChatCompletionMessage(content="'''python\nprint('hello world')'''", role="assistant", tool_calls=None),
ChatCompletionMessage(content='"""python\nprint(\'hello world\')"""', role="assistant", tool_calls=None),
ChatCompletionMessage(content="'''python\nprint(\"hello world\")'''", role="assistant", tool_calls=None),
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
]
)
choices = [
@ -79,6 +67,15 @@ def tool_calls_rsp():
]
@pytest.fixture
def json_decode_error():
function_rsp = Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute")
tool_calls = [ChatCompletionMessageToolCall(type="function", id=f"call_{0}", function=function_rsp)]
message = ChatCompletionMessage(content=None, role="assistant", tool_calls=tool_calls)
choices = [Choice(finish_reason="tool_calls", logprobs=None, index=0, message=message)]
return ChatCompletion(id="0", choices=choices, created=0, model="gpt-4", object="chat.completion")
class TestOpenAI:
def test_make_client_kwargs_without_proxy(self):
instance = OpenAILLM(mock_llm_config)
@ -107,6 +104,12 @@ class TestOpenAI:
else:
code["language"] == "python"
def test_aask_code_json_decode_error(self, json_decode_error):
instance = OpenAILLM(mock_llm_config)
with pytest.raises(json.decoder.JSONDecodeError) as e:
instance.get_choice_function_arguments(json_decode_error)
assert "JSONDecodeError" in str(e)
@pytest.mark.asyncio
async def test_gen_image():