2024-08-12 00:32:42 -07:00
from __future__ import annotations
from langchain . chains import GraphCypherQAChain
from langchain_community . graphs import Neo4jGraph
from langchain_core . documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community . vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES , ALGORITHM , API_SECRET_KEY , SECRET_KEY
2024-08-16 20:35:50 -07:00
from prompts import CYPHER_QA_PROMPT , DATE_TODAY , DOC_DESCRIPTION_PROMPT , GRAPH_QUERY_GEN_PROMPT , SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT , DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import DescriptionResponse , PrecisionQuery , PrecisionResponse , UserQuery , DocMeta , RetrivedDocList , UserQueryResponse , UserQueryWithChatHistory , VectorSearchQuery
2024-08-12 00:32:42 -07:00
from langchain_experimental . text_splitter import SemanticChunker
2024-08-16 20:35:50 -07:00
from langchain_core . prompts import ChatPromptTemplate , MessagesPlaceholder
from langchain_core . messages import HumanMessage , SystemMessage , AIMessage
2024-08-12 00:32:42 -07:00
# Our imports
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
2024-08-13 19:56:54 -07:00
from DataExample import examples
2024-08-12 00:32:42 -07:00
# Auth Libs
from fastapi import FastAPI , Depends , HTTPException , Request , status
from sqlalchemy . orm import Session
from fastapi . security import OAuth2PasswordBearer , OAuth2PasswordRequestForm
from jose import JWTError , jwt
from datetime import datetime , timedelta
from passlib . context import CryptContext
from models import User
from database import SessionLocal , engine
from pydantic import BaseModel
from fastapi . middleware . cors import CORSMiddleware
# Application instance; the route decorators and the CORS middleware in this file register on it.
app = FastAPI ( )
2024-08-16 20:35:50 -07:00
# GraphCypherQAChain
2024-08-12 00:32:42 -07:00
def _to_doc_meta(doc):
    """Build one ``DocMeta`` from a retrieved document, defaulting missing metadata."""
    meta = doc.metadata
    return DocMeta(
        BrowsingSessionId=meta.get(" BrowsingSessionId ", " NOT AVAILABLE "),
        VisitedWebPageURL=meta.get(" VisitedWebPageURL ", " NOT AVAILABLE "),
        VisitedWebPageTitle=meta.get(" VisitedWebPageTitle ", " NOT AVAILABLE "),
        VisitedWebPageDateWithTimeInISOString=meta.get(
            " VisitedWebPageDateWithTimeInISOString ", " NOT AVAILABLE "
        ),
        VisitedWebPageReffererURL=meta.get(" VisitedWebPageReffererURL ", " NOT AVAILABLE "),
        # The duration has no textual placeholder; it stays None when absent.
        VisitedWebPageVisitDurationInMilliseconds=meta.get(
            " VisitedWebPageVisitDurationInMilliseconds "
        ),
        VisitedWebPageContent=doc.page_content if doc.page_content else " NOT AVAILABLE ",
    )


def _unique_doc_metas(docs):
    """Convert documents to ``DocMeta`` entries and drop duplicates.

    When two entries compare equal, the later one is the one that is kept.
    """
    metas = [_to_doc_meta(doc) for doc in docs]
    return [meta for pos, meta in enumerate(metas) if meta not in metas[pos + 1:]]


@app.post(" / ")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
    """Answer a user query from the graph, falling back to similarity-search RAG.

    The query is first sent through a ``GraphCypherQAChain``. If that yields a
    usable answer, the chain's retrieved context is turned into a refined
    vector-search query and the matching documents are returned with the
    answer. If anything in that path fails, or the answer contains the
    "don't know" marker, the answer is generated from a plain similarity
    search on the original query instead.

    Args:
        data: Request payload carrying the API secret, the Neo4j and OpenAI
            credentials, and the user query.
        response_model: Unused inside the function body.
            NOTE(review): presumably this was meant to be passed to
            ``@app.post(..., response_model=...)`` rather than declared as a
            parameter - confirm before changing the signature.

    Returns:
        A ``UserQueryResponse`` with the related documents and the answer text.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match
            ``API_SECRET_KEY``.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail=" Unauthorized ")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    llm = ChatOpenAI(
        model=" gpt-4o-mini ",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    # Query expansion is currently disabled: the raw user query is used as-is.
    query = data.query

    embeddings = OpenAIEmbeddings(
        model=" text-embedding-ada-002 ",
        api_key=data.openaikey,
    )

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        cypher_llm=llm,
        verbose=True,
        validate_cypher=True,
        qa_prompt=CYPHER_QA_PROMPT,
        qa_llm=llm,
        return_intermediate_steps=True,
        top_k=5,
    )

    vector_index = Neo4jVector.from_existing_graph(
        embeddings,
        graph=graph,
        search_type=" hybrid ",
        node_label=" Document ",
        text_node_properties=[" text "],
        embedding_node_property=" embedding ",
    )

    # Computed up front and never reassigned, so the fallback below always
    # works from a complete, consistent result set even if the graph path
    # fails halfway through.
    fallback_docs = vector_index.similarity_search(data.query, k=5)
    fallback_metas = _unique_doc_metas(fallback_docs)

    try:
        graph_answer = chain.invoke({" query ": query})

        if " don ' t know " in graph_answer[" result "]:
            raise Exception(" No response from graph ")

        structured_llm = llm.with_structured_output(VectorSearchQuery)
        doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm

        # Second intermediate step of the chain holds the retrieved graph context.
        refined = doc_extract_chain.invoke(graph_answer[" intermediate_steps "][1][" context "])

        related = _unique_doc_metas(
            vector_index.similarity_search(refined.searchquery, k=5)
        )
        return UserQueryResponse(relateddocs=related, response=graph_answer[" result "])
    except Exception:
        # Deliberate best-effort: any failure on the graph path falls back to
        # similarity-search RAG. ``Exception`` (not a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit propagating.
        searchchain = SIMILARITY_SEARCH_PROMPT | llm
        response = searchchain.invoke({" question ": data.query, " context ": fallback_docs})
        return UserQueryResponse(relateddocs=fallback_metas, response=response.content)
2024-08-16 20:35:50 -07:00
#RETURN n LIMIT 25;
@app.post(" /precision ")
def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse):
    """Return documents matching exact date, duration, URL and session filters.

    The Cypher query is built from fixed clauses and ``$parameters``; request
    values are never spliced into the query text, so they cannot alter the
    query structure.

    Args:
        data: Request payload with the API secret, Neo4j credentials and the
            filters ``daterange`` (two values), ``timerange`` (two values),
            and the optional ``webpageurl`` / ``sessionid``.
        response_model: Unused inside the function body.
            NOTE(review): presumably meant for ``@app.post(...,
            response_model=...)`` - confirm before changing the signature.

    Returns:
        A ``PrecisionResponse`` holding one ``DocMeta`` per browsing session;
        the text of further rows from the same session is appended to the
        first row seen for that session.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match
            ``API_SECRET_KEY``.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail=" Unauthorized ")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    conditions = [
        "d.VisitedWebPageDateWithTimeInISOString >= $date_from",
        "d.VisitedWebPageDateWithTimeInISOString <= $date_to",
    ]
    params = {"date_from": data.daterange[0], "date_to": data.daterange[1]}

    if data.timerange[0] >= data.timerange[1]:
        # An empty or inverted duration range means "no upper bound".
        conditions.append("d.VisitedWebPageVisitDurationInMilliseconds >= 0")
    else:
        # NOTE(review): assumes the timerange entries are numeric, as the
        # original inlined them unquoted - confirm against PrecisionQuery.
        conditions.append("d.VisitedWebPageVisitDurationInMilliseconds >= $duration_min")
        conditions.append("d.VisitedWebPageVisitDurationInMilliseconds <= $duration_max")
        params["duration_min"] = data.timerange[0]
        params["duration_max"] = data.timerange[1]

    if data.webpageurl:
        conditions.append("d.VisitedWebPageURL CONTAINS $url_fragment")
        params["url_fragment"] = data.webpageurl.lower()

    if data.sessionid:
        conditions.append("d.BrowsingSessionId = $session_id")
        params["session_id"] = data.sessionid

    cypher = "MATCH (d:Document) WHERE " + " AND ".join(conditions) + " RETURN d;"
    rows = graph.query(cypher, params=params)

    # NOTE(review): the row and property key literals below are kept exactly as
    # written in this file, including their surrounding spaces, while the query
    # returns the column as `d` - confirm the intended keys.
    by_session = {}
    for row in rows:
        node = row[' d ']
        session_key = node[' BrowsingSessionId ']
        if session_key in by_session:
            by_session[session_key][' text '] += node[' text ']
        else:
            by_session[session_key] = node

    docs = [
        DocMeta(
            BrowsingSessionId=node[' BrowsingSessionId '],
            VisitedWebPageURL=node[' VisitedWebPageURL '],
            VisitedWebPageVisitDurationInMilliseconds=node[' VisitedWebPageVisitDurationInMilliseconds '],
            VisitedWebPageTitle=node[' VisitedWebPageTitle '],
            VisitedWebPageReffererURL=node[' VisitedWebPageReffererURL '],
            VisitedWebPageDateWithTimeInISOString=node[' VisitedWebPageDateWithTimeInISOString '],
            VisitedWebPageContent=node[' text '],
        )
        for node in by_session.values()
    ]

    return PrecisionResponse(documents=docs)
# Multi DOC Chat
@app.post(" /chat/docs ")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
    """Continue a chat over documents, replaying the supplied history.

    Each history entry is mapped onto a LangChain message by its ``type``;
    system entries are prefixed with ``DATE_TODAY`` and the fixed
    question-answering instructions. The new user query fills the trailing
    ``{input}`` slot of the prompt.

    Returns a ``DescriptionResponse`` with the model's reply; raises a 401
    ``HTTPException`` when the API secret does not match.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail=" Unauthorized ")

    llm = ChatOpenAI(
        model=" gpt-4o-mini ",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    messages = []
    for turn in data.chat:
        kind = turn.type
        if kind == ' system ':
            preamble = DATE_TODAY + """ You are an helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question .
If you don ' t know the answer, just say that you don ' t know .
Context : """
            messages.append(SystemMessage(content=preamble + str(turn.content)))
        elif kind == ' ai ':
            messages.append(AIMessage(content=turn.content))
        elif kind == ' human ':
            messages.append(HumanMessage(content=turn.content))

    # Template slot for the new question; filled by the invoke() call below.
    messages.append((" human ", " {input} "))

    qa_prompt = ChatPromptTemplate.from_messages(messages)
    reply = (qa_prompt | llm).invoke({" input ": data.query})

    return DescriptionResponse(response=reply.content)
2024-08-12 00:32:42 -07:00
# DOC DESCRIPTION
@app.post(" /kb/doc ")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
    """Generate a description for the document text carried in ``data.query``.

    Returns a ``DescriptionResponse`` with the model output; raises a 401
    ``HTTPException`` when the API secret does not match.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail=" Unauthorized ")

    llm = ChatOpenAI(
        model=" gpt-4o-mini ",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    # The request's query field holds the document to describe.
    description_chain = DOC_DESCRIPTION_PROMPT | llm
    result = description_chain.invoke({" document ": data.query})

    return DescriptionResponse(response=result.content)
# SAVE DOCS TO GRAPH DB
@app.post(" /kb/ ")
def populate_graph(apires: RetrivedDocList):
    """Chunk the submitted documents, extract graph documents and store them.

    The documents are split with a ``SemanticChunker``, converted by the
    ``LLMGraphTransformer`` and written to the Neo4j graph together with
    their source chunks. Raises a 401 ``HTTPException`` when the API secret
    does not match.
    """
    if apires.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail=" Unauthorized ")

    print(" STARTED ")

    graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)

    llm = ChatOpenAI(
        model=" gpt-4o-mini ",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=apires.openaikey,
    )

    embeddings = OpenAIEmbeddings(
        model=" text-embedding-ada-002 ",
        api_key=apires.openaikey,
    )

    llm_transformer = LLMGraphTransformer(llm=llm)

    raw_documents = [
        Document(page_content=item.pageContent, metadata=item.metadata)
        for item in apires.documents
    ]

    chunks = SemanticChunker(embeddings=embeddings).split_documents(raw_documents)
    graph_documents = llm_transformer.convert_to_graph_documents(chunks)

    graph.add_graph_documents(
        graph_documents,
        baseEntityLabel=True,
        include_source=True,
    )

    print(" FINISHED ")

    return {" success ": " Graph Will be populated Shortly "}
#AUTH CODE
# OAuth2 bearer-token scheme; used as the default dependency of verify_token below.
oauth2_scheme = OAuth2PasswordBearer ( tokenUrl = " token " )
# Recommended for Local Setups
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
# ]
# CORS middleware for the whole application.
# NOTE(review): a wildcard origin entry combined with allow_credentials = True looks
# overly permissive outside local setups - consider the explicit origins list sketched above.
app . add_middleware (
CORSMiddleware ,
allow_origins = [ " * " ] , # Wildcard entry instead of an explicit origin list
allow_credentials = True ,
allow_methods = [ " * " ] , # Allows all methods
allow_headers = [ " * " ] , # Allows all headers
)
# Dependency
def get_db():
    """Yield a fresh ``SessionLocal`` session and close it when the consumer is done."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether the consumer finished normally or raised.
        session.close()
# Password hashing context; create_user hashes with it and authenticate_user verifies with it.
pwd_context = CryptContext ( schemes = [ " bcrypt " ] , deprecated = " auto " )
class UserCreate(BaseModel):
    """Request body for user registration."""

    # Login name; checked for uniqueness by register_user.
    username: str
    # Plain-text password; hashed by create_user before it is stored.
    password: str
    # Shared API secret; compared against API_SECRET_KEY by register_user.
    apisecretkey: str
def get_user_by_username(db: Session, username: str):
    """Return the first ``User`` whose username equals *username* (the result of ``first()``)."""
    matching_users = db.query(User).filter(User.username == username)
    return matching_users.first()
def create_user(db: Session, user: UserCreate):
    """Hash the password, insert a new ``User`` row, commit, and return a status string."""
    new_user = User(
        username=user.username,
        hashed_password=pwd_context.hash(user.password),
    )
    db.add(new_user)
    db.commit()
    return " complete "
@app.post(" /register ")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new user.

    Raises a 401 ``HTTPException`` when the API secret does not match and a
    400 when the username is already taken; otherwise returns whatever
    ``create_user`` returns.
    """
    if user.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail=" Unauthorized ")

    if get_user_by_username(db, username=user.username):
        raise HTTPException(status_code=400, detail=" Username already registered ")

    # Drop the secret from the payload before handing it on.
    del user.apisecretkey
    return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session):
    """Return the matching ``User`` when the credentials check out, else ``False``."""
    candidate = db.query(User).filter(User.username == username).first()
    # Short-circuit keeps the password check from running when no user was found.
    if not candidate or not pwd_context.verify(password, candidate.hashed_password):
        return False
    return candidate
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Encode *data* plus an expiry claim as a signed JWT.

    Args:
        data: Claims to embed; copied, so the caller's dict is not modified.
        expires_delta: Token lifetime. When omitted (or falsy, e.g. a zero
            delta) the token expires 15 minutes from now.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Local import: the file-level import only brings in datetime and timedelta.
    from datetime import timezone

    # Timezone-aware "now" replaces datetime.utcnow(), which is deprecated
    # since Python 3.12; the instant it denotes is the same.
    now = datetime.now(timezone.utc)
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)

    to_encode = data.copy()
    to_encode.update({" exp ": now + lifetime})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
@app.post(" /token ")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange a username/password form for a bearer access token.

    Raises a 401 ``HTTPException`` when ``authenticate_user`` rejects the
    credentials; otherwise returns the token and its type.
    """
    account = authenticate_user(form_data.username, form_data.password, db)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=" Incorrect username or password ",
            headers={" WWW-Authenticate ": " Bearer "},
        )

    token = create_access_token(
        data={" sub ": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {" access_token ": token, " token_type ": " bearer "}
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode *token* and return its payload.

    Raises a 403 ``HTTPException`` when decoding raises ``JWTError`` or the
    payload has no subject claim.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=403, detail=" Token is invalid or expired ")

    subject: str = claims.get(" sub ")
    if subject is None:
        raise HTTPException(status_code=403, detail=" Token is invalid or expired ")
    return claims
@app.get(" /verify-token/ {token} ")
async def verify_user_token(token: str):
    """Validate the token from the path; ``verify_token`` raises when it is rejected."""
    verify_token(token=token)
    confirmation = {" message ": " Token is valid "}
    return confirmation
# Script entry point: serve the app with uvicorn when the guard below matches.
if __name__ == " __main__ ":
    import uvicorn

    uvicorn.run(
        app,
        host=" 127.0.0.1 ",
        port=8000,
    )