feat: gpt-researcher custom response. Now very close to Perplexity.

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2024-10-24 22:19:29 -07:00
parent dfb0967dbe
commit 46c9b228df
5 changed files with 215 additions and 194 deletions

View file

@ -5,14 +5,13 @@ from langchain_core.documents import Document
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from sqlalchemy import insert
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
from prompts import DATE_TODAY
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_unstructured import UnstructuredLoader
# Hierarchical indices class
from HIndices import HIndices
from Utils.stringify import stringify
# Auth Libs
@ -31,13 +30,20 @@ import os
from dotenv import load_dotenv
load_dotenv()
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
FAST_LLM = os.environ.get("FAST_LLM")
IS_LOCAL_SETUP = True if FAST_LLM.startswith("ollama") else False
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
SECRET_KEY = os.environ.get("SECRET_KEY")
UNSTRUCTURED_API_KEY = os.environ.get("UNSTRUCTURED_API_KEY")
def extract_model_name(model_string: str) -> str:
    """Return the model part of a ``provider:model`` identifier.

    Only the first colon separates the provider from the model, so model
    names that themselves contain colons are returned intact
    (``"ollama:llama3:8b"`` -> ``"llama3:8b"``).

    Args:
        model_string: Identifier in ``provider:model`` form, e.g.
            ``"openai:gpt-4o-mini"``.

    Returns:
        Everything after the first colon.

    Raises:
        ValueError: If ``model_string`` contains no colon.
    """
    # partition() splits on the first colon only and never raises, which
    # lets us report a missing separator with a meaningful message instead
    # of an opaque tuple-unpacking error.
    _provider, separator, model_name = model_string.partition(":")
    if not separator:
        raise ValueError(
            f"Expected a 'provider:model' string, got {model_string!r}"
        )
    return model_name
MODEL_NAME = extract_model_name(FAST_LLM)
app = FastAPI()
# Dependency
@ -71,6 +77,7 @@ async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_sche
chunking_strategy="basic",
max_characters=90000,
include_orig_elements=False,
strategy="fast",
)
filedocs = loader.load()
@ -117,7 +124,7 @@ async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_sche
db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == 'true':
if IS_LOCAL_SETUP == True:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=api_key)
@ -144,61 +151,22 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
query = data.query
search_space = data.search_space
if(IS_LOCAL_SETUP == 'true'):
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
else:
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
#Experimental
def decompose_query(original_query: str):
"""
Decompose the original query into simpler sub-queries.
Args:
original_query (str): The original complex query
Returns:
List[str]: A list of simpler sub-queries
"""
if(IS_LOCAL_SETUP == 'true'):
response = subquery_decomposer_chain.invoke(original_query)
else:
response = subquery_decomposer_chain.invoke(original_query).content
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
return sub_queries
# Create hierarchical indices
if(IS_LOCAL_SETUP == 'true'):
if(IS_LOCAL_SETUP == True):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)
# For Those Who Want HyDe Questions
# sub_queries = decompose_query(query)
#Implement HyDe over it if you crazy
sub_queries = []
sub_queries.append(query)
duplicate_related_summary_docs = []
context_to_answer = ""
for sub_query in sub_queries:
localreturn = index.local_search(query=sub_query, search_space=search_space)
globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)
context_to_answer += localreturn + "\n\n" + globalreturn
# I know this is not the best way to do it, but I am too lazy to change it now
related_summary_docs = index.summary_vector_search(query=sub_query, search_space=search_space)
duplicate_related_summary_docs.extend(related_summary_docs)
@ -222,16 +190,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
returnDocs.append(entry)
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
finalans = index.new_search(query=query, search_space=search_space)
if(IS_LOCAL_SETUP == 'true'):
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else:
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
except JWTError:
@ -310,7 +273,7 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == 'true':
if IS_LOCAL_SETUP == True:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=apires.openaikey)
@ -336,10 +299,10 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
llm = OllamaLLM(model="mistral-nemo",temperature=0)
if(IS_LOCAL_SETUP == True):
llm = OllamaLLM(model=MODEL_NAME,temperature=0)
else:
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=data.openaikey)
chatHistory = []
@ -365,7 +328,7 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
response = descriptionchain.invoke({"input": data.query})
if(IS_LOCAL_SETUP == 'true'):
if(IS_LOCAL_SETUP == True):
return DescriptionResponse(response=response)
else:
return DescriptionResponse(response=response.content)
@ -384,7 +347,7 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
if(IS_LOCAL_SETUP == True):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)