rm repeated openai embedding

This commit is contained in:
better629 2023-10-01 11:26:42 +08:00
parent fd4ee7256c
commit ba45c3710a
2 changed files with 18 additions and 28 deletions

View file

@ -5,8 +5,8 @@
import datetime
from numpy import dot
from numpy.linalg import norm
from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
from utils.utils import embedding_tools
from ..memory.agent_memory import AgentMemory, BasicMemory
from ..utils.utils import get_embedding
def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str,
@ -73,7 +73,7 @@ def extract_relevance(query, score_list):
"""
抽取相关性
"""
query_embedding = embedding_tools(query)
query_embedding = get_embedding(query)
# 进行
for i in range(len(score_list)):
result = cos_sim(score_list[i]["memory"].embedding_key, query_embedding)

View file

@ -27,16 +27,7 @@ def write_json_file(json_file: str, data: list, encoding=None):
json.dump(data, fout, ensure_ascii=False, indent=4)
def embedding_tools(query):
embedding_result = openai.Embedding.create(
model="text-embedding-ada-002",
input=query
)
embedding_key = embedding_result['data'][0]["embedding"]
return embedding_key
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
"""
Reads in a csv file to a list of list. If header is True, it returns a
tuple with (header row, all rows)
@ -45,39 +36,38 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
RETURNS:
List of list where the component lists are the rows of the file.
"""
if not header:
if not header:
analysis_list = []
with open(curr_file) as f_analysis_file:
with open(curr_file) as f_analysis_file:
data_reader = csv.reader(f_analysis_file, delimiter=",")
for count, row in enumerate(data_reader):
if strip_trail:
for count, row in enumerate(data_reader):
if strip_trail:
row = [i.strip() for i in row]
analysis_list += [row]
return analysis_list
else:
else:
analysis_list = []
with open(curr_file) as f_analysis_file:
with open(curr_file) as f_analysis_file:
data_reader = csv.reader(f_analysis_file, delimiter=",")
for count, row in enumerate(data_reader):
if strip_trail:
for count, row in enumerate(data_reader):
if strip_trail:
row = [i.strip() for i in row]
analysis_list += [row]
return analysis_list[0], analysis_list[1:]
def get_embedding(text, model: str="text-embedding-ada-002"):
def get_embedding(text, model: str = "text-embedding-ada-002"):
text = text.replace("\n", " ")
if not text:
if not text:
text = "this is blank"
return openai.Embedding.create(
input=[text], model=model)['data'][0]['embedding']
def generate_poig_score(scratch, event_type, description):
if "is idle" in description:
def generate_poig_score(scratch, event_type, description):
if "is idle" in description:
return 1
if event_type == "action":
if event_type == "action":
return get_poignancy_action(scratch, description)[0]
elif event_type == "chat":
elif event_type == "chat":
return get_poignancy_chat(scratch, description)[0]