mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge pull request #490 from orange-crow/feat_openai_ask_code_via_function
Feat openai ask code via function.
This commit is contained in:
commit
f57355c889
5 changed files with 236 additions and 7 deletions
|
|
@ -5,6 +5,7 @@
|
|||
@Author : alexanderwu
|
||||
@File : base_gpt_api.py
|
||||
"""
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -14,7 +15,8 @@ from metagpt.provider.base_chatbot import BaseChatbot
|
|||
|
||||
class BaseGPTAPI(BaseChatbot):
|
||||
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
system_prompt = 'You are a helpful assistant.'
|
||||
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "user", "content": msg}
|
||||
|
|
@ -110,11 +112,50 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""Required to provide the first text of choice"""
|
||||
return rsp.get("choices")[0]["message"]["content"]
|
||||
|
||||
def get_choice_function(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function of choice
|
||||
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",
|
||||
and "tool_calls" must include "function", for example:
|
||||
{...
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute",
|
||||
"arguments": "{\n \"language\": \"python\",\n \"code\": \"print('Hello, World!')\"\n}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
...}
|
||||
:return dict: return first function of choice, for exmaple,
|
||||
{'name': 'execute', 'arguments': '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}'}
|
||||
"""
|
||||
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict()
|
||||
|
||||
def get_choice_function_arguments(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
||||
:param dict rsp: same as in self.get_choice_function(rsp)
|
||||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"])
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return '\n'.join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
||||
def messages_to_dict(self, messages):
|
||||
"""objects to [{"role": "user", "content": msg}] etc."""
|
||||
return [i.to_dict() for i in messages]
|
||||
|
||||
30
metagpt/provider/constant.py
Normal file
30
metagpt/provider/constant.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# function in tools, https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
|
||||
# Reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.14/interpreter/llm/setup_openai_coding_llm.py
|
||||
GENERAL_FUNCTION_SCHEMA = {
|
||||
"name": "execute",
|
||||
"description": "Executes code on the user's machine, **in the users local environment**, and returns the output",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"language": {
|
||||
"type": "string",
|
||||
"description": "The programming language (required parameter to the `execute` function)",
|
||||
"enum": [
|
||||
"python",
|
||||
"R",
|
||||
"shell",
|
||||
"applescript",
|
||||
"javascript",
|
||||
"html",
|
||||
"powershell",
|
||||
],
|
||||
},
|
||||
"code": {"type": "string", "description": "The code to execute (required)"},
|
||||
},
|
||||
"required": ["language", "code"],
|
||||
},
|
||||
}
|
||||
|
||||
# tool_choice value for general_function_schema
|
||||
# https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
GENERAL_TOOL_CHOICE = {"type": "function", "function": {"name": "execute"}}
|
||||
|
|
@ -21,6 +21,8 @@ from tenacity import (
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
|
|
@ -110,7 +112,6 @@ class CostManager(metaclass=Singleton):
|
|||
"""
|
||||
return self.total_completion_tokens
|
||||
|
||||
|
||||
def get_total_cost(self):
|
||||
"""
|
||||
Get the total cost of API calls.
|
||||
|
|
@ -120,7 +121,6 @@ class CostManager(metaclass=Singleton):
|
|||
"""
|
||||
return self.total_cost
|
||||
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
"""Get all costs"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
|
|
@ -181,7 +181,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict]) -> dict:
|
||||
def _cons_kwargs(self, messages: list[dict], **configs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
|
|
@ -190,6 +190,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"temperature": 0.3,
|
||||
"timeout": 3,
|
||||
}
|
||||
if configs:
|
||||
kwargs.update(configs)
|
||||
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
if CONFIG.deployment_name and CONFIG.deployment_id:
|
||||
raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
|
||||
|
|
@ -239,6 +242,81 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
rsp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
def _func_configs(self, messages: list[dict], **kwargs) -> dict:
|
||||
"""
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {
|
||||
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
|
||||
"tool_choice": GENERAL_TOOL_CHOICE,
|
||||
}
|
||||
kwargs.update(configs)
|
||||
|
||||
return self._cons_kwargs(messages, **kwargs)
|
||||
|
||||
def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict:
|
||||
rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
if isinstance(messages, list):
|
||||
messages = [Message(msg) if isinstance(msg, str) else msg for msg in messages]
|
||||
return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages]
|
||||
|
||||
if isinstance(messages, Message):
|
||||
messages = [messages.to_dict()]
|
||||
elif isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!"
|
||||
)
|
||||
return messages
|
||||
|
||||
def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
Examples:
|
||||
|
||||
>>> llm = OpenAIGPTAPI()
|
||||
>>> llm.ask_code("Write a python hello world code.")
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
|
||||
>>> llm.ask_code(msg)
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = self._chat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
Examples:
|
||||
|
||||
>>> llm = OpenAIGPTAPI()
|
||||
>>> rsp = await llm.ask_code("Write a python hello world code.")
|
||||
>>> rsp
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
|
||||
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
|
||||
usage = {}
|
||||
if CONFIG.calc_usage:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ langchain==0.0.231
|
|||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
openai
|
||||
openai>=0.28.0
|
||||
openpyxl
|
||||
beautifulsoup4==4.12.2
|
||||
pandas==2.0.3
|
||||
|
|
|
|||
80
tests/metagpt/provider/test_openai.py
Normal file
80
tests/metagpt/provider/test_openai.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_list_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_list_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
print(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue