Mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-04-26 17:26:23 +02:00
feat: gpt-researcher custom response. Now very close to Perplexity.
This commit is contained in:
parent dfb0967dbe
commit 46c9b228df
5 changed files with 215 additions and 194 deletions
@@ -5,14 +5,13 @@ from langchain_core.documents import Document
 from langchain_ollama import OllamaLLM
 from langchain_openai import ChatOpenAI
 from sqlalchemy import insert
-from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
+from prompts import DATE_TODAY
 from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
 from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
 from langchain_unstructured import UnstructuredLoader
 
 # Hierarchical Indices class
 from HIndices import HIndices
 
 from Utils.stringify import stringify
 
 # Auth Libs
@@ -31,13 +30,20 @@ import os
 from dotenv import load_dotenv
 load_dotenv()
 
-IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
+FAST_LLM = os.environ.get("FAST_LLM")
+IS_LOCAL_SETUP = True if FAST_LLM.startswith("ollama") else False
 ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
 ALGORITHM = os.environ.get("ALGORITHM")
 API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
 SECRET_KEY = os.environ.get("SECRET_KEY")
 UNSTRUCTURED_API_KEY = os.environ.get("UNSTRUCTURED_API_KEY")
 
+def extract_model_name(model_string: str) -> str:
+    part1, part2 = model_string.split(":", 1)  # Split into two parts at the first colon
+    return part2
+
+MODEL_NAME = extract_model_name(FAST_LLM)
+
 app = FastAPI()
 
 # Dependency
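The FAST_LLM variable now carries both provider and model in one provider:model string: an "ollama" prefix switches the server to local mode, and everything after the first colon becomes MODEL_NAME. A minimal sketch of that parsing (editor's sketch, not part of the diff; the example strings are hypothetical, not repo defaults):

    # Hypothetical FAST_LLM values for illustration only
    for fast_llm in ("ollama:mistral-nemo", "openai:gpt-4o-mini"):
        provider, model = fast_llm.split(":", 1)   # split once, at the first colon
        is_local = fast_llm.startswith("ollama")
        print(provider, model, is_local)
    # -> ollama mistral-nemo True
    # -> openai gpt-4o-mini False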
@@ -71,6 +77,7 @@ async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_sche
     chunking_strategy="basic",
     max_characters=90000,
     include_orig_elements=False,
+    strategy="fast",
 )
 
 filedocs = loader.load()
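The newly added strategy="fast" asks unstructured to use its quick rule-based text extraction rather than the slower model-based "hi_res" path, which suits the upload route's cheap-chunking goal. A hedged sketch of the loader call this hunk implies (file_path and the surrounding wiring are assumptions; only the keyword arguments appear in the diff):

    from langchain_unstructured import UnstructuredLoader

    # Editor's sketch; file_path is hypothetical
    loader = UnstructuredLoader(
        file_path="example.pdf",
        chunking_strategy="basic",
        max_characters=90000,
        include_orig_elements=False,
        strategy="fast",
    )
    filedocs = loader.load()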
@@ -117,7 +124,7 @@ async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_sche
 db.commit()
 
 # Create hierarchical indices
-if IS_LOCAL_SETUP == 'true':
+if IS_LOCAL_SETUP == True:
     index = HIndices(username=username)
 else:
     index = HIndices(username=username, api_key=api_key)
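Every `== 'true'` check in this file changes to `== True` for the same reason: IS_LOCAL_SETUP used to be the raw environment string, so only a text comparison worked; it is now a real boolean derived from the FAST_LLM prefix. A before/after sketch with hypothetical values:

    # Before: raw env string
    old_flag = "true"                    # hypothetical os.environ value
    print(old_flag == 'true')            # -> True (old_flag == True would be False)

    # After: real boolean derived from FAST_LLM
    new_flag = "ollama:mistral-nemo".startswith("ollama")
    print(new_flag == True)              # -> True

Idiomatic Python would drop the comparison entirely (`if IS_LOCAL_SETUP:`), but the commit keeps the explicit form.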
@@ -144,61 +151,22 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
 
 query = data.query
 search_space = data.search_space
 
-if(IS_LOCAL_SETUP == 'true'):
-    sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
-    qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
-else:
-    sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
-    qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
-
-# Create an LLMChain for sub-query decomposition
-subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
-
-# Experimental
-def decompose_query(original_query: str):
-    """
-    Decompose the original query into simpler sub-queries.
-
-    Args:
-        original_query (str): The original complex query
-
-    Returns:
-        List[str]: A list of simpler sub-queries
-    """
-    if(IS_LOCAL_SETUP == 'true'):
-        response = subquery_decomposer_chain.invoke(original_query)
-    else:
-        response = subquery_decomposer_chain.invoke(original_query).content
-
-    sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
-    return sub_queries
-
 # Create Hierarchical Indices
-if(IS_LOCAL_SETUP == 'true'):
+if(IS_LOCAL_SETUP == True):
     index = HIndices(username=username)
 else:
     index = HIndices(username=username,api_key=data.openaikey)
 
 # For Those Who Want HyDE Questions
 # sub_queries = decompose_query(query)
 
 # Implement HyDE over it if you're crazy
 sub_queries = []
 sub_queries.append(query)
 
 duplicate_related_summary_docs = []
-context_to_answer = ""
 for sub_query in sub_queries:
-    localreturn = index.local_search(query=sub_query, search_space=search_space)
-    globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)
-
-    context_to_answer += localreturn + "\n\n" + globalreturn
-
+    # I know this is not the best way to do it, but I am too lazy to change it now
+    related_summary_docs = index.summary_vector_search(query=sub_query, search_space=search_space)
     duplicate_related_summary_docs.extend(related_summary_docs)
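For context on the deleted code path: decompose_query expected the sub-query LLM to emit one sub-query per line and stripped a literal 'Sub-queries:' header. A worked example of that parsing logic (the response text is hypothetical):

    response = "Sub-queries:\nWhat does local_search return?\nHow are summaries ranked?"
    sub_queries = [q.strip() for q in response.split('\n')
                   if q.strip() and not q.strip().startswith('Sub-queries:')]
    print(sub_queries)
    # -> ['What does local_search return?', 'How are summaries ranked?']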
|
@ -222,16 +190,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|||
|
||||
returnDocs.append(entry)
|
||||
|
||||
|
||||
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
|
||||
|
||||
|
||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||
finalans = index.new_search(query=query, search_space=search_space)
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
|
||||
else:
|
||||
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
|
||||
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
|
||||
|
||||
|
||||
except JWTError:
|
||||
|
|
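This is the heart of the commit: instead of stitching local_search/global_search output into CONTEXT_ANSWER_PROMPT and invoking a QA LLM, the route delegates the whole answer to HIndices.new_search, which per the commit message produces a gpt-researcher style custom response. A usage sketch under the assumptions visible in this diff only: new_search takes query and search_space keywords and returns plain text for both local and OpenAI setups, which is why the IS_LOCAL_SETUP branch around the return disappears.

    # Editor's sketch; username, api_key, and the argument values are hypothetical
    index = HIndices(username="demo-user", api_key="sk-...")
    answer = index.new_search(query="What is SurfSense?", search_space="general")
    assert isinstance(answer, str)   # no .content unwrap needed anymore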
@@ -310,7 +273,7 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
 db.commit()
 
 # Create hierarchical indices
-if IS_LOCAL_SETUP == 'true':
+if IS_LOCAL_SETUP == True:
     index = HIndices(username=username)
 else:
     index = HIndices(username=username, api_key=apires.openaikey)
@@ -336,10 +299,10 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
 if username is None:
     raise HTTPException(status_code=403, detail="Token is invalid or expired")
 
-if(IS_LOCAL_SETUP == 'true'):
-    llm = OllamaLLM(model="mistral-nemo",temperature=0)
+if(IS_LOCAL_SETUP == True):
+    llm = OllamaLLM(model=MODEL_NAME,temperature=0)
 else:
-    llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
+    llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=data.openaikey)
 
 chatHistory = []
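The chat route no longer hard-codes "mistral-nemo" and "gpt-4o-mini"; both branches read MODEL_NAME, so the served model is controlled entirely by FAST_LLM. With settings like these (hypothetical, mirroring the previously hard-coded models) the new code reproduces the old behavior exactly:

    # Editor's sketch, not part of the diff
    fast_llm = "openai:gpt-4o-mini"
    model_name = fast_llm.split(":", 1)[1]
    assert model_name == "gpt-4o-mini"   # what the old code hard-coded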
@@ -365,7 +328,7 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
 
 response = descriptionchain.invoke({"input": data.query})
 
-if(IS_LOCAL_SETUP == 'true'):
+if(IS_LOCAL_SETUP == True):
     return DescriptionResponse(response=response)
 else:
     return DescriptionResponse(response=response.content)
@@ -384,7 +347,7 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
 if username is None:
     raise HTTPException(status_code=403, detail="Token is invalid or expired")
 
-if(IS_LOCAL_SETUP == 'true'):
+if(IS_LOCAL_SETUP == True):
     index = HIndices(username=username)
 else:
     index = HIndices(username=username,api_key=data.openaikey)