diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index d35418b77..6ff507037 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -5,8 +5,8 @@ import datetime from numpy import dot from numpy.linalg import norm -from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory -from utils.utils import embedding_tools +from ..memory.agent_memory import AgentMemory, BasicMemory +from ..utils.utils import get_embedding def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, @@ -73,7 +73,7 @@ def extract_relevance(query, score_list): """ 抽取相关性 """ - query_embedding = embedding_tools(query) + query_embedding = get_embedding(query) # 进行 for i in range(len(score_list)): result = cos_sim(score_list[i]["memory"].embedding_key, query_embedding) diff --git a/examples/st_game/utils/utils.py b/examples/st_game/utils/utils.py index e6b29a667..5cd110e9f 100644 --- a/examples/st_game/utils/utils.py +++ b/examples/st_game/utils/utils.py @@ -27,16 +27,7 @@ def write_json_file(json_file: str, data: list, encoding=None): json.dump(data, fout, ensure_ascii=False, indent=4) -def embedding_tools(query): - embedding_result = openai.Embedding.create( - model="text-embedding-ada-002", - input=query - ) - embedding_key = embedding_result['data'][0]["embedding"] - return embedding_key - - -def read_csv_to_list(curr_file: str, header=False, strip_trail=True): +def read_csv_to_list(curr_file: str, header=False, strip_trail=True): """ Reads in a csv file to a list of list. If header is True, it returns a tuple with (header row, all rows) @@ -45,39 +36,38 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True): RETURNS: List of list where the component lists are the rows of the file. """ - if not header: + if not header: analysis_list = [] - with open(curr_file) as f_analysis_file: + with open(curr_file) as f_analysis_file: data_reader = csv.reader(f_analysis_file, delimiter=",") - for count, row in enumerate(data_reader): - if strip_trail: + for count, row in enumerate(data_reader): + if strip_trail: row = [i.strip() for i in row] analysis_list += [row] return analysis_list - else: + else: analysis_list = [] - with open(curr_file) as f_analysis_file: + with open(curr_file) as f_analysis_file: data_reader = csv.reader(f_analysis_file, delimiter=",") - for count, row in enumerate(data_reader): - if strip_trail: + for count, row in enumerate(data_reader): + if strip_trail: row = [i.strip() for i in row] analysis_list += [row] return analysis_list[0], analysis_list[1:] -def get_embedding(text, model: str="text-embedding-ada-002"): +def get_embedding(text, model: str = "text-embedding-ada-002"): text = text.replace("\n", " ") - if not text: + if not text: text = "this is blank" return openai.Embedding.create( input=[text], model=model)['data'][0]['embedding'] -def generate_poig_score(scratch, event_type, description): - if "is idle" in description: +def generate_poig_score(scratch, event_type, description): + if "is idle" in description: return 1 - if event_type == "action": + if event_type == "action": return get_poignancy_action(scratch, description)[0] - elif event_type == "chat": + elif event_type == "chat": return get_poignancy_chat(scratch, description)[0] -