feat: add ask code via function.

This commit is contained in:
刘棒棒 2023-11-18 16:09:56 +08:00
parent adf40e6783
commit 0153bb5416
4 changed files with 120 additions and 6 deletions

View file

@ -5,6 +5,7 @@
@Author : alexanderwu
@File : base_gpt_api.py
"""
import json
from abc import abstractmethod
from typing import Optional
@ -14,7 +15,8 @@ from metagpt.provider.base_chatbot import BaseChatbot
class BaseGPTAPI(BaseChatbot):
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
system_prompt = 'You are a helpful assistant.'
system_prompt = "You are a helpful assistant."
def _user_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "content": msg}
@ -108,11 +110,23 @@ class BaseGPTAPI(BaseChatbot):
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]
def get_choice_function(self, rsp: dict) -> dict:
"""Required to provide the first function of choice. for example:
"function": {
"name": "execute",
"arguments": "{\n \"language\": \"python\",\n \"code\": \"print('Hello, World!')\"\n}"
}
"""
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict()
def get_choice_function_arguments(self, rsp: dict) -> dict:
"""Required to provide the first function arguments of choice."""
return json.loads(self.get_choice_function(rsp)["arguments"])
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return '\n'.join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""
return [i.to_dict() for i in messages]

View file

@ -21,6 +21,7 @@ from tenacity import (
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.utils.function_schema import general_function_schema, general_tool_choice
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
TOKEN_COSTS,
@ -110,7 +111,6 @@ class CostManager(metaclass=Singleton):
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
@ -120,7 +120,6 @@ class CostManager(metaclass=Singleton):
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
@ -181,7 +180,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict]) -> dict:
def _cons_kwargs(self, messages: list[dict], **configs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
@ -190,6 +189,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"temperature": 0.3,
"timeout": 3,
}
if configs:
kwargs.update(configs)
if CONFIG.openai_api_type == "azure":
if CONFIG.deployment_name and CONFIG.deployment_id:
raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
@ -239,6 +241,53 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)
def _func_configs(self, messages: list[dict], **kwargs) -> dict:
if "tools" not in kwargs:
configs = {
"tools": [{"type": "function", "function": general_function_schema}],
"tool_choice": general_tool_choice,
}
kwargs.update(configs)
return self._cons_kwargs(messages, **kwargs)
def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict:
rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.get("usage"))
return rsp
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs))
self._update_costs(rsp.get("usage"))
return rsp
def ask_code(self, messages: list[dict], **kwargs) -> dict:
"""Use function of tools to ask a code.
https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
Examples:
>>> llm = OpenAIGPTAPI()
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
>>> llm.ask_code(msg)
{'language': 'python', 'code': "print('Hello, World!')"}
"""
rsp = self._chat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask_code(self, messages: list[dict], **kwargs) -> dict:
"""Use function of tools to ask a code.
https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
Examples:
>>> llm = OpenAIGPTAPI()
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
"""
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
usage = {}
if CONFIG.calc_usage:

View file

@ -0,0 +1,29 @@
# function in tools, https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
general_function_schema = {
"name": "execute",
"description": "Executes code on the user's machine, **in the users local environment**, and returns the output",
"parameters": {
"type": "object",
"properties": {
"language": {
"type": "string",
"description": "The programming language (required parameter to the `execute` function)",
"enum": [
"python",
"R",
"shell",
"applescript",
"javascript",
"html",
"powershell",
],
},
"code": {"type": "string", "description": "The code to execute (required)"},
},
"required": ["language", "code"],
},
}
# tool_choice value for general_function_schema
# https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
general_tool_choice = {"type": "function", "function": {"name": "execute"}}

View file

@ -0,0 +1,22 @@
import pytest
from metagpt.provider.openai_api import OpenAIGPTAPI
@pytest.mark.asyncio
async def test_aask_code():
llm = OpenAIGPTAPI()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code():
llm = OpenAIGPTAPI()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0