mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 05:42:37 +02:00
init project
This commit is contained in:
commit
c871144507
204 changed files with 7220 additions and 0 deletions
16
metagpt/tools/__init__.py
Normal file
16
metagpt/tools/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/4/29 15:35
|
||||
@Author : alexanderwu
|
||||
@File : __init__.py
|
||||
"""
|
||||
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
    """Identifiers for the search back ends a caller can select.

    Values come from ``auto()``, i.e. the integers 1, 2, 3 in declaration
    order, so reordering the members changes their numeric values.
    """

    SERPAPI_GOOGLE = auto()  # value 1
    DIRECT_GOOGLE = auto()  # value 2
    CUSTOM_ENGINE = auto()  # value 3
|
||||
111
metagpt/tools/prompt_writer.py
Normal file
111
metagpt/tools/prompt_writer.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/2 16:03
|
||||
@Author : alexanderwu
|
||||
@File : prompt_writer.py
|
||||
"""
|
||||
from abc import ABC
|
||||
from typing import Union
|
||||
|
||||
|
||||
class GPTPromptGenerator:
    """Build "reverse" prompts: given an LLM output, ask the LLM for the input.

    Three prompt styles are supported: instruction, chatbot and query.
    """

    def __init__(self):
        # Style name -> bound renderer. Insertion order is the order of the
        # list returned by gen() when style == 'all'.
        self._generators = {
            'instruction': self.gen_instruction_style,
            'chatbot': self.gen_chatbot_style,
            'query': self.gen_query_style,
        }

    def gen_instruction_style(self, example):
        """Instruction style: given the output, ask the LLM for the instruction."""
        return "\n".join((
            "指令:X",
            f"输出:{example}",
            "这个输出可能来源于什么样的指令?",
            "X:",
        ))

    def gen_chatbot_style(self, example):
        """Chatbot style: given the reply, ask the LLM for the user message."""
        return "\n".join((
            "你是一个对话机器人。一个用户给你发送了一条非正式的信息,你的回复如下。",
            "信息:X",
            f"回复:{example}",
            "非正式信息X是什么?",
            "X:",
        ))

    def gen_query_style(self, example):
        """Query style: given the document, ask the LLM for the search query."""
        return "\n".join((
            "你是一个搜索引擎。一个人详细地查询了某个问题,关于这个查询最相关的文档如下。",
            "查询:X",
            f"文档:{example} 详细的查询X是什么?",
            "X:",
        ))

    def gen(self, example: str, style: str = 'all') -> Union[list[str], str]:
        """Render one prompt, or all of them, asking the LLM for the matching input.

        :param example: sample of the expected LLM output
        :param style: (all|instruction|chatbot|query)
        :return: a single prompt string for a named style, or a list with one
            prompt per style when ``style`` is ``'all'``
        :raises KeyError: if ``style`` is not one of the known names
        """
        if style == 'all':
            return [render(example) for render in self._generators.values()]
        return self._generators[style](example)
|
||||
|
||||
|
||||
class WikiHowTemplate:
    """Prompt templates asking for step-by-step instructions (WikiHow style)."""

    def __init__(self):
        # One template per line; gen() relies on the newline separators.
        self._prompts = "\n".join([
            "Give me {step} steps to {question}.",
            "How to {question}?",
            "Do you know how can I {question}?",
            "List {step} instructions to {question}.",
            "What are some tips to {question}?",
            "What are some steps to {question}?",
            "Can you provide {step} clear and concise instructions on how to {question}?",
            "I'm interested in learning how to {question}. Could you break it down into {step} easy-to-follow steps?",
            "For someone who is new to {question}, what would be {step} key steps to get started?",
            "What is the most efficient way to {question}? Could you provide a list of {step} steps?",
            "Do you have any advice on how to {question} successfully? Maybe a step-by-step guide with {step} steps?",
            "I'm trying to accomplish {question}. Could you walk me through the process with {step} detailed instructions?",
            "What are the essential {step} steps to {question}?",
            "I need to {question}, but I'm not sure where to start. Can you give me {step} actionable steps?",
            "As a beginner in {question}, what are the {step} basic steps I should take?",
            "I'm looking for a comprehensive guide on how to {question}. Can you provide {step} detailed steps?",
            "Could you outline {step} practical steps to achieve {question}?",
            "What are the {step} fundamental steps to consider when attempting to {question}?",
        ])

    def gen(self, question: str, step: str) -> list[str]:
        """Fill every template with ``question`` and ``step``; return one prompt per line."""
        rendered = self._prompts.format(step=step, question=question)
        return rendered.splitlines()
|
||||
|
||||
|
||||
class EnronTemplate:
    """Prompt templates asking for an email on a given subject."""

    def __init__(self):
        # One template per line; gen() relies on the newline separators.
        self._prompts = "\n".join([
            'Write an email with the subject "{subj}".',
            'Can you craft an email with the subject {subj}?',
            'Would you be able to compose an email and use {subj} as the subject?',
            'Create an email about {subj}.',
            'Draft an email and include the subject "{subj}".',
            'Generate an email about {subj}.',
            'Hey, can you shoot me an email about {subj}?',
            'Do you mind crafting an email for me with {subj} as the subject?',
            'Can you whip up an email with the subject of "{subj}"?',
            'Hey, can you write an email and use "{subj}" as the subject?',
            'Can you send me an email about {subj}?',
        ])

    def gen(self, subj):
        """Fill every template with ``subj`` and return one prompt per line."""
        rendered = self._prompts.format(subj=subj)
        return rendered.splitlines()
|
||||
|
||||
|
||||
class BEAGECTemplate:
    """Prompt templates asking for grammar/spelling/style correction of a document."""

    def __init__(self):
        # One prompt per line; gen() relies on the newline separators.
        self._prompts = "\n".join([
            "Edit and revise this document to improve its grammar, vocabulary, spelling, and style.",
            "Revise this document to correct all the errors related to grammar, spelling, and style.",
            "Refine this document by eliminating all grammatical, lexical, and orthographic errors and improving its writing style.",
            "Polish this document by rectifying all errors related to grammar, vocabulary, and writing style.",
            "Enhance this document by correcting all the grammar errors and style issues, and improving its overall quality.",
            "Rewrite this document by fixing all grammatical, lexical and orthographic errors.",
            "Fix all grammar errors and style issues and rewrite this document.",
            "Take a stab at fixing all the mistakes in this document and make it sound better.",
            "Give this document a once-over and clean up any grammar or spelling errors.",
            "Tweak this document to make it read smoother and fix any mistakes you see.",
            "Make this document sound better by fixing all the grammar, spelling, and style issues.",
            "Proofread this document and fix any errors that make it sound weird or confusing.",
        ])

    def gen(self):
        """Return the prompts as a list, one entry per line."""
        prompts = self._prompts
        return prompts.splitlines()
|
||||
126
metagpt/tools/search_engine.py
Normal file
126
metagpt/tools/search_engine.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/6 20:15
|
||||
@Author : alexanderwu
|
||||
@File : search_engine.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from metagpt.logs import logger
|
||||
from duckduckgo_search import ddg
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
|
||||
|
||||
config = Config()
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SearchEngine:
    """Dispatch a search query to the configured back end.

    TODO: integrate Google Search and put it behind a reverse proxy.
    Note: reaching Google from here requires Proxifier or a similar global proxy.
    - DDG: https://pypi.org/project/duckduckgo-search/
    - GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
    """

    def __init__(self, engine=None, run_func=None):
        """Store the engine choice and the optional custom search callable.

        Args:
            engine: A ``SearchEngineType`` member; when falsy, the value of
                ``Config().search_engine`` is used instead.
            run_func: Callable invoked as ``run_func(query)`` when the engine
                is ``SearchEngineType.CUSTOM_ENGINE``.
        """
        self.config = Config()
        self.engine = engine or self.config.search_engine
        self.run_func = run_func

    @classmethod
    def run_google(cls, query, max_results=8):
        """Query Google through ``google_official_search`` and log the result."""
        # A DuckDuckGo call (``ddg``) used to live here; Google is used now.
        hits = google_official_search(query, num_results=max_results)
        logger.info(hits)
        return hits

    async def run(self, query, max_results=8):
        """Run ``query`` on the selected engine and return that engine's result.

        ``max_results`` is only forwarded to the DIRECT_GOOGLE back end.

        Raises:
            NotImplementedError: If ``self.engine`` is not a known engine type.
        """
        engine = self.engine
        if engine == SearchEngineType.SERPAPI_GOOGLE:
            return await SerpAPIWrapper().run(query)
        if engine == SearchEngineType.DIRECT_GOOGLE:
            return SearchEngine.run_google(query, max_results)
        if engine == SearchEngineType.CUSTOM_ENGINE:
            return self.run_func(query)
        raise NotImplementedError
|
||||
|
||||
|
||||
def google_official_search(query: str, num_results: int = 8, focus=('snippet', 'link', 'title')) -> str | list[dict]:
    """Return the results of a Google search using the official Google API.

    Args:
        query (str): The search query.
        num_results (int): The number of results to request.
        focus: Keys to keep from every result item; all other keys are dropped.

    Returns:
        list[dict]: One dict per result, restricted to the ``focus`` keys,
        when the request succeeds.
        str: A message starting with ``"Error: "`` when the API raises
        ``HttpError`` (errors are returned, not raised).
    """
    # Imported lazily so the module can be imported without the Google client.
    from googleapiclient.discovery import build
    from googleapiclient.errors import HttpError

    try:
        api_key = config.google_api_key
        custom_search_engine_id = config.google_cse_id

        service = build("customsearch", "v1", developerKey=api_key)
        result = (
            service.cse()
            .list(q=query, cx=custom_search_engine_id, num=num_results)
            .execute()
        )

        # Extract the search result items from the response.
        search_results = result.get("items", [])

        # Keep only the requested fields of every item.
        search_results_details = [
            {key: value for key, value in item.items() if key in focus}
            for item in search_results
        ]
    except HttpError as e:
        # The error body is normally JSON, but do not let a malformed body
        # turn a reportable API error into an unrelated exception.
        try:
            error_details = json.loads(e.content.decode())
        except (ValueError, AttributeError):
            return f"Error: {e}"

        error = error_details.get("error", {}) if isinstance(error_details, dict) else {}
        if not isinstance(error, dict):
            error = {}

        # Check if the error is related to an invalid or missing API key.
        if error.get("code") == 403 and "invalid API key" in error.get("message", ""):
            return "Error: The provided Google API key is invalid or missing."
        return f"Error: {e}"

    return search_results_details
|
||||
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
    """Return the results of a google search in a safe format.

    Args:
        results (str | list): The search results.

    Returns:
        str: A JSON dump when ``results`` is a list; otherwise the input
        string with everything that cannot be UTF-8 encoded removed.
    """
    if not isinstance(results, list):
        # Round-trip through UTF-8, silently dropping unencodable characters.
        return results.encode("utf-8", "ignore").decode("utf-8")

    # FIXME: AutoGPT also applies .encode("utf-8", "ignore") on this branch;
    # it was removed here, which looks odd.
    return json.dumps(list(results))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import asyncio

    # SearchEngine.run is an async *instance* method: it needs an instance
    # and must be driven by an event loop. Calling it on the class fails
    # with a TypeError for the missing ``self`` argument.
    asyncio.run(SearchEngine().run(query='wtf'))
|
||||
44
metagpt/tools/search_engine_meilisearch.py
Normal file
44
metagpt/tools/search_engine_meilisearch.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/22 21:33
|
||||
@Author : alexanderwu
|
||||
@File : search_engine_meilisearch.py
|
||||
"""
|
||||
|
||||
from metagpt.logs import logger
|
||||
import meilisearch
|
||||
from meilisearch.index import Index
|
||||
from typing import List
|
||||
|
||||
|
||||
class DataSource:
    """A named source of documents together with its URL."""

    def __init__(self, name: str, url: str):
        """Record the source's ``name`` and ``url`` as attributes."""
        self.name, self.url = name, url
|
||||
|
||||
|
||||
class MeilisearchEngine:
    """Thin wrapper around a Meilisearch client that tracks one active index."""

    def __init__(self, url, token):
        """Create the client for the server at ``url`` using ``token``."""
        self.client = meilisearch.Client(url, token)
        # Index used by search(); set by add_documents() or set_index().
        self._index: Index = None

    def set_index(self, index):
        """Make ``index`` the index queried by :meth:`search`."""
        self._index = index

    def add_documents(self, data_source: DataSource, documents: List[dict]):
        """Add ``documents`` to the index named ``"<data_source.name>_index"``.

        The index then becomes the active index for :meth:`search`.
        """
        index_name = f"{data_source.name}_index"
        # The previous membership test against ``client.get_indexes()``
        # compared a string with Index objects, so it never matched and
        # ``create_index`` ran on every call. Adding documents creates a
        # missing index on the server, so a local handle plus the primary key
        # is enough and works whether or not the index already exists.
        index = self.client.index(index_name)
        index.add_documents(documents, primary_key='id')
        self.set_index(index)

    def search(self, query):
        """Search the active index and return the list of hits.

        Best effort: any failure (including no active index) is logged and
        an empty list is returned.
        """
        try:
            search_results = self._index.search(query)
            return search_results['hits']
        except Exception as e:
            # Handle Meilisearch API errors. Kept broad on purpose so a search
            # failure never propagates; reported via the module logger
            # instead of print.
            logger.error(f"MeiliSearch API错误: {e}")
            return []
|
||||
115
metagpt/tools/search_engine_serpapi.py
Normal file
115
metagpt/tools/search_engine_serpapi.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/23 18:27
|
||||
@Author : alexanderwu
|
||||
@File : search_engine_serpapi.py
|
||||
"""
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from metagpt.logs import logger
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import Config
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
    """Async wrapper around the SerpAPI search endpoint.

    Requests are sent with ``aiohttp`` to ``https://serpapi.com/search``.
    The API key defaults to ``serpapi_api_key`` from the project ``Config``
    and can be overridden by passing ``serpapi_api_key`` to the constructor.
    """

    search_engine: Any  #: :meta private:
    # Default query parameters merged into every request.
    params: dict = Field(
        default={
            "engine": "google",
            "google_domain": "google.com",
            "gl": "us",
            "hl": "en",
        }
    )
    # NOTE: ``Config`` here is the project Config imported at module level;
    # the nested pydantic ``Config`` class below is not defined yet at this
    # point of the class body.
    config = Config()
    serpapi_api_key: Optional[str] = config.serpapi_api_key
    # Optional caller-owned session; when unset a temporary one is created
    # for each request.
    aiosession: Optional[aiohttp.ClientSession] = None

    class Config:
        arbitrary_types_allowed = True

    async def run(self, query: str, **kwargs: Any) -> str:
        """Run query through SerpAPI and parse result async.

        ``kwargs`` is accepted for interface compatibility and is not used.
        """
        return self._process_response(await self.results(query))

    async def results(self, query: str) -> dict:
        """Use aiohttp to run query through SerpAPI and return the results async."""

        def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
            params = self.get_params(query)
            params["source"] = "python"
            if self.serpapi_api_key:
                params["serp_api_key"] = self.serpapi_api_key
            params["output"] = "json"
            url = "https://serpapi.com/search"
            return url, params

        url, params = construct_url_and_params()
        if not self.aiosession:
            # No shared session: open one for this request and close it after.
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params) as response:
                    res = await response.json()
        else:
            async with self.aiosession.get(url, params=params) as response:
                res = await response.json()

        return res

    def get_params(self, query: str) -> Dict[str, str]:
        """Get parameters for SerpAPI: the defaults plus the API key and query."""
        _params = {
            "api_key": self.serpapi_api_key,
            "q": query,
        }
        # Request-specific values win over the defaults in self.params.
        return {**self.params, **_params}

    @staticmethod
    def _process_response(res: dict) -> str:
        """Process response from SerpAPI.

        Returns the best single answer found, a newline, and the ``str()`` of
        a list holding the title/snippet/link of the answer box (when it has
        a snippet) and of every organic result.

        Raises:
            ValueError: If the response contains an ``"error"`` key.
        """
        focus = ['title', 'snippet', 'link']

        def get_focused(item: dict) -> dict:
            # Keep only the fields of interest.
            return {key: value for key, value in item.items() if key in focus}

        if "error" in res:
            raise ValueError(f"Got error from SerpAPI: {res['error']}")

        # A missing or empty "organic_results" must fall through to the
        # "No good search result found" branch instead of raising
        # KeyError/IndexError.
        organic_results = res.get("organic_results") or []

        if "answer_box" in res and "answer" in res["answer_box"]:
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res and "snippet" in res["answer_box"]:
            toret = res["answer_box"]["snippet"]
        elif "answer_box" in res and res["answer_box"].get("snippet_highlighted_words"):
            # Non-empty list guaranteed by the condition above.
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif "sports_results" in res and "game_spotlight" in res["sports_results"]:
            toret = res["sports_results"]["game_spotlight"]
        elif "knowledge_graph" in res and "description" in res["knowledge_graph"]:
            toret = res["knowledge_graph"]["description"]
        elif organic_results and "snippet" in organic_results[0]:
            toret = organic_results[0]["snippet"]
        else:
            toret = "No good search result found"

        toret_l = []
        if "answer_box" in res and "snippet" in res["answer_box"]:
            toret_l.append(get_focused(res["answer_box"]))
        toret_l.extend(get_focused(item) for item in organic_results)

        return str(toret) + '\n' + str(toret_l)
|
||||
27
metagpt/tools/translator.py
Normal file
27
metagpt/tools/translator.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/4/29 15:36
|
||||
@Author : alexanderwu
|
||||
@File : translator.py
|
||||
"""
|
||||
|
||||
prompt = '''
|
||||
# 指令
|
||||
接下来,作为一位拥有20年翻译经验的翻译专家,当我给出英文句子或段落时,你将提供通顺且具有可读性的{LANG}翻译。注意以下要求:
|
||||
1. 确保翻译结果流畅且易于理解
|
||||
2. 无论提供的是陈述句或疑问句,我都只进行翻译
|
||||
3. 不添加与原文无关的内容
|
||||
|
||||
# 原文
|
||||
{ORIGINAL}
|
||||
|
||||
# 译文
|
||||
'''
|
||||
|
||||
|
||||
class Translator:
    """Renders the module-level translation ``prompt`` template."""

    @classmethod
    def translate_prompt(cls, original, lang='中文'):
        """Return the prompt with ``{LANG}`` and ``{ORIGINAL}`` filled in."""
        fields = {'LANG': lang, 'ORIGINAL': original}
        return prompt.format(**fields)
|
||||
291
metagpt/tools/ut_writer.py
Normal file
291
metagpt/tools/ut_writer.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
|
||||
|
||||
ICL_SAMPLE = '''接口定义:
|
||||
```text
|
||||
接口名称:元素打标签
|
||||
接口路径:/projects/{project_key}/node-tags
|
||||
Method:POST
|
||||
|
||||
请求参数:
|
||||
路径参数:
|
||||
project_key
|
||||
|
||||
Body参数:
|
||||
名称 类型 是否必须 默认值 备注
|
||||
nodes array 是 节点
|
||||
node_key string 否 节点key
|
||||
tags array 否 节点原标签列表
|
||||
node_type string 否 节点类型 DATASET / RECIPE
|
||||
operations array 是
|
||||
tags array 否 操作标签列表
|
||||
mode string 否 操作类型 ADD / DELETE
|
||||
|
||||
返回数据:
|
||||
名称 类型 是否必须 默认值 备注
|
||||
code integer 是 状态码
|
||||
msg string 是 提示信息
|
||||
data object 是 返回数据
|
||||
list array 否 node列表 true / false
|
||||
node_type string 否 节点类型 DATASET / RECIPE
|
||||
node_key string 否 节点key
|
||||
```
|
||||
|
||||
单元测试:
|
||||
```python
|
||||
@pytest.mark.parametrize(
|
||||
"project_key, nodes, operations, expected_msg",
|
||||
[
|
||||
("project_key", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "success"),
|
||||
("project_key", [{"node_key": "dataset_002", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["tag1"], "mode": "DELETE"}], "success"),
|
||||
("", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "缺少必要的参数 project_key"),
|
||||
(123, [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "参数类型不正确"),
|
||||
("project_key", [{"node_key": "a"*201, "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "请求参数超出字段边界")
|
||||
]
|
||||
)
|
||||
def test_node_tags(project_key, nodes, operations, expected_msg):
|
||||
pass
|
||||
```
|
||||
以上是一个 接口定义 与 单元测试 样例。
|
||||
接下来,请你扮演一个Google 20年经验的专家测试经理,在我给出 接口定义 后,回复我单元测试。有几个要求
|
||||
1. 只输出一个 `@pytest.mark.parametrize` 与对应的test_<接口名>函数(内部pass,不实现)
|
||||
-- 函数参数中包含expected_msg,用于结果校验
|
||||
2. 生成的测试用例使用较短的文本或数字,并且尽量紧凑
|
||||
3. 如果需要注释,使用中文
|
||||
|
||||
如果你明白了,请等待我给出接口定义,并只回答"明白",以节省token
|
||||
'''
|
||||
|
||||
ACT_PROMPT_PREFIX = '''参考测试类型:如缺少请求参数,字段边界校验,字段类型不正确
|
||||
请在一个 `@pytest.mark.parametrize` 作用域内输出10个测试用例
|
||||
```text
|
||||
'''
|
||||
|
||||
YFT_PROMPT_PREFIX = '''参考测试类型:如SQL注入,跨站点脚本(XSS),非法访问和越权访问,认证和授权,参数验证,异常处理,文件上传和下载
|
||||
请在一个 `@pytest.mark.parametrize` 作用域内输出10个测试用例
|
||||
```text
|
||||
'''
|
||||
|
||||
OCR_API_DOC = '''```text
|
||||
接口名称:OCR识别
|
||||
接口路径:/api/v1/contract/treaty/task/ocr
|
||||
Method:POST
|
||||
|
||||
请求参数:
|
||||
路径参数:
|
||||
|
||||
Body参数:
|
||||
名称 类型 是否必须 默认值 备注
|
||||
file_id string 是
|
||||
box array 是
|
||||
contract_id number 是 合同id
|
||||
start_time string 否 yyyy-mm-dd
|
||||
end_time string 否 yyyy-mm-dd
|
||||
extract_type number 否 识别类型 1-导入中 2-导入后 默认1
|
||||
|
||||
返回数据:
|
||||
名称 类型 是否必须 默认值 备注
|
||||
code integer 是
|
||||
message string 是
|
||||
data object 是
|
||||
```
|
||||
'''
|
||||
|
||||
|
||||
class UTGenerator:
    """UT generator: builds unit tests from API documentation (a Swagger JSON file)."""

    def __init__(self, swagger_file: str, ut_py_path: str, questions_path: str,
                 chatgpt_method: str = "API", template_prefix=YFT_PROMPT_PREFIX) -> None:
        """Initialise the UT generator.

        Args:
            swagger_file: Path of the Swagger JSON file.
            ut_py_path: Directory where the generated test cases are stored.
            questions_path: Directory where the prompts are stored, to ease
                later troubleshooting.
            chatgpt_method: Invocation method; only "API" is accepted.
            template_prefix: Prompt prefix to use; defaults to YFT_PROMPT_PREFIX.

        Raises:
            AssertionError: If ``chatgpt_method`` is not "API".
        """
        self.swagger_file = swagger_file
        self.ut_py_path = ut_py_path
        self.questions_path = questions_path
        # NOTE: an assert is skipped under ``python -O``; it is kept so the
        # raised exception type stays what existing callers see.
        assert chatgpt_method in ["API"], "非法chatgpt_method"
        self.chatgpt_method = chatgpt_method

        # ICL: In-Context Learning. An example is supplied and the model is
        # asked to imitate it.
        self.icl_sample = ICL_SAMPLE
        self.template_prefix = template_prefix

    def get_swagger_json(self) -> dict:
        """Load the Swagger JSON from the local file."""
        with open(self.swagger_file, "r", encoding="utf-8") as file:
            swagger_json = json.load(file)
        return swagger_json

    def __para_to_str(self, prop, required, name=""):
        """Format one parameter/property as a tab-separated table row."""
        name = name or prop["name"]
        # Properties defined only through ``$ref`` carry no "type" key.
        ptype = prop.get("type", "")
        title = prop.get("title", "")
        desc = prop.get("description", "")
        return f'{name}\t{ptype}\t{"是" if required else "否"}\t{title}\t{desc}'

    def _para_to_str(self, prop):
        """Format a parameter object that carries its own name and ``required`` flag."""
        required = prop.get("required", False)
        return self.__para_to_str(prop, required)

    def para_to_str(self, name, prop, prop_object_required):
        """Format a schema property; it is required when listed in ``prop_object_required``."""
        required = name in prop_object_required
        return self.__para_to_str(prop, required, name)

    def build_object_properties(self, node, prop_object_required, level: int = 0) -> str:
        """Recursively render the sub-properties of object and array[object] types.

        Args:
            node: Either a single query/header/formData parameter object or a
                mapping of property name to property schema.
            prop_object_required: Names of the required properties of ``node``.
            level: Current recursion depth, used for indentation.

        Returns:
            The rendered rows, one per line.
        """
        doc = ""

        def own_required(schema):
            # On a schema object "required" is a list of names; on a
            # parameter object it is a boolean, which is not usable here.
            required = schema.get("required", [])
            return required if isinstance(required, list) else []

        def dive_into_object(node):
            """If the node is an object, recursively render its sub-properties."""
            if node.get("type") == "object":
                sub_properties = node.get("properties", {})
                # Each nested object is checked against its OWN required
                # list, not the one belonging to its parent.
                return self.build_object_properties(sub_properties, own_required(node), level=level + 1)
            return ""

        if node.get("in", "") in ["query", "header", "formData"]:
            doc += f'{" " * level}{self._para_to_str(node)}\n'
            doc += dive_into_object(node)
            return doc

        for name, prop in node.items():
            doc += f'{" " * level}{self.para_to_str(name, prop, prop_object_required)}\n'
            doc += dive_into_object(prop)
            if prop.get("type") == "array":
                items = prop.get("items", {})
                doc += dive_into_object(items)
        return doc

    def get_tags_mapping(self) -> dict:
        """Group the Swagger operations by tag.

        Returns:
            Dict: ``{tag: {path: {method: operation}}}``.
        """
        swagger_data = self.get_swagger_json()
        paths = swagger_data["paths"]
        tags = {}

        for path, path_obj in paths.items():
            for method, method_obj in path_obj.items():
                for tag in method_obj["tags"]:
                    tags.setdefault(tag, {}).setdefault(path, {})[method] = method_obj

        return tags

    def generate_ut(self, include_tags) -> bool:
        """Generate the test-case files.

        Args:
            include_tags: Tags to process; ``None`` processes every tag.

        Returns:
            Always ``True``.
        """
        tags = self.get_tags_mapping()
        for tag, paths in tags.items():
            if include_tags is None or tag in include_tags:
                self._generate_ut(tag, paths)
        return True

    def build_api_doc(self, node: dict, path: str, method: str) -> str:
        """Render one operation as the plain-text API description for the prompt.

        The text ends with a closing code fence; the opening fence is expected
        to come from the template prefix.
        """
        summary = node["summary"]

        doc = f"接口名称:{summary}\n接口路径:{path}\nMethod:{method.upper()}\n"
        doc += "\n请求参数:\n"
        if "parameters" in node:
            parameters = node["parameters"]
            doc += "路径参数:\n"

            # param["in"]: path / formData / body / query / header
            for param in parameters:
                if param["in"] == "path":
                    doc += f'{param["name"]} \n'

            doc += "\nBody参数:\n"
            doc += "名称\t类型\t是否必须\t默认值\t备注\n"
            for param in parameters:
                location = param["in"]
                if location == "body":
                    schema = param.get("schema", {})
                    prop_properties = schema.get("properties", {})
                    prop_required = schema.get("required", [])
                    doc += self.build_object_properties(prop_properties, prop_required)
                elif location in ("query", "header", "formData"):
                    doc += self.build_object_properties(param, [])
                # Path parameters are already listed above. Passing them (or
                # any other location) on would make build_object_properties
                # iterate the parameter dict itself and subscript its string
                # values, which raises TypeError.

        # Render the response data.
        doc += "\n返回数据:\n"
        doc += "名称\t类型\t是否必须\t默认值\t备注\n"
        responses = node["responses"]
        response = responses.get("200", {})
        schema = response.get("schema", {})
        properties = schema.get("properties", {})
        required = schema.get("required", [])

        doc += self.build_object_properties(properties, required)
        doc += "\n"
        doc += "```"

        return doc

    def _store(self, data, base, folder, fname):
        """Write ``data`` to ``base/folder/fname``, creating the directory if needed."""
        file_path = self.get_file_path(Path(base) / folder, fname)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(data)

    def ask_gpt_and_save(self, question: str, tag: str, fname: str):
        """Ask the model the question, then store both the question and the answer."""
        messages = [self.icl_sample, question]
        result = self.gpt_msgs_to_code(messages=messages)

        self._store(question, self.questions_path, tag, f"{fname}.txt")
        self._store(result, self.ut_py_path, tag, f"{fname}.py")

    def _generate_ut(self, tag, paths):
        """Generate one test file per operation below ``paths``.

        Args:
            tag: Module (tag) name, used as the output sub-directory.
            paths: Mapping ``{path: {method: operation}}``.
        """
        for path, path_obj in paths.items():
            for method, node in path_obj.items():
                summary = node["summary"]
                question = self.template_prefix
                question += self.build_api_doc(node, path, method)
                self.ask_gpt_and_save(question, tag, summary)

    def gpt_msgs_to_code(self, messages: list) -> str:
        """Send the messages through the configured invocation method.

        Returns an empty string when the method is not "API".
        """
        result = ''
        if self.chatgpt_method == "API":
            result = GPTAPI().ask_code(msgs=messages)

        return result

    def get_file_path(self, base: Path, fname: str):
        """Return ``base/fname`` as a string, creating ``base`` if it is missing.

        Args:
            base: Directory path.
            fname: File name.
        """
        path = Path(base)
        path.mkdir(parents=True, exist_ok=True)
        file_path = path / fname
        return str(file_path)
|
||||
Loading…
Add table
Add a link
Reference in a new issue