This commit is contained in:
DESKTOP-RTLN3BA\$punk 2024-08-12 00:32:42 -07:00
parent 24f641347d
commit 63146aa9b7
86 changed files with 18766 additions and 0 deletions

12
backend/.gitignore vendored Normal file
View file

@ -0,0 +1,12 @@
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
__pycache__
__pycache__/
.__pycache__

View file

@ -0,0 +1,830 @@
import asyncio
import json
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.runnables import RunnableConfig
examples = [
{
"text": (
"Adam is a software engineer in Microsoft since 2009, "
"and last year he got an award as the Best Talent"
),
"head": "Adam",
"head_type": "Person",
"relation": "WORKS_FOR",
"tail": "Microsoft",
"tail_type": "Company",
},
{
"text": (
"Adam is a software engineer in Microsoft since 2009, "
"and last year he got an award as the Best Talent"
),
"head": "Adam",
"head_type": "Person",
"relation": "HAS_AWARD",
"tail": "Best Talent",
"tail_type": "Award",
},
{
"text": (
"Microsoft is a tech company that provide "
"several products such as Microsoft Word"
),
"head": "Microsoft Word",
"head_type": "Product",
"relation": "PRODUCED_BY",
"tail": "Microsoft",
"tail_type": "Company",
},
{
"text": "Microsoft Word is a lightweight app that accessible offline",
"head": "Microsoft Word",
"head_type": "Product",
"relation": "HAS_CHARACTERISTIC",
"tail": "lightweight app",
"tail_type": "Characteristic",
},
{
"text": "Microsoft Word is a lightweight app that accessible offline",
"head": "Microsoft Word",
"head_type": "Product",
"relation": "HAS_CHARACTERISTIC",
"tail": "accessible offline",
"tail_type": "Characteristic",
},
]
system_prompt = (
"# Knowledge Graph Instructions for GPT-4\n"
"## 1. Overview\n"
"You are a top-tier algorithm designed for extracting information in structured "
"formats to build a knowledge graph.\n"
"Try to capture as much information from the text as possible without "
"sacrificing accuracy. Do not add any information that is not explicitly "
"mentioned in the text.\n"
"- **Nodes** represent entities and concepts.\n"
"- The aim is to achieve simplicity and clarity in the knowledge graph, making it\n"
"accessible for a vast audience.\n"
"## 2. Labeling Nodes\n"
"- **Consistency**: Ensure you use available types for node labels.\n"
"Ensure you use basic or elementary types for node labels.\n"
"- For example, when you identify an entity representing a person, "
"always label it as **'person'**. Avoid using more specific terms "
"like 'mathematician' or 'scientist'."
"- **Node IDs**: Never utilize integers as node IDs. Node IDs should be "
"names or human-readable identifiers found in the text.\n"
"- **Relationships** represent connections between entities or concepts.\n"
"Ensure consistency and generality in relationship types when constructing "
"knowledge graphs. Instead of using specific and momentary types "
"such as 'BECAME_PROFESSOR', use more general and timeless relationship types "
"like 'PROFESSOR'. Make sure to use general and timeless relationship types!\n"
"## 3. Coreference Resolution\n"
"- **Maintain Entity Consistency**: When extracting entities, it's vital to "
"ensure consistency.\n"
'If an entity, such as "John Doe", is mentioned multiple times in the text '
'but is referred to by different names or pronouns (e.g., "Joe", "he"),'
"always use the most complete identifier for that entity throughout the "
'knowledge graph. In this example, use "John Doe" as the entity ID.\n'
"Remember, the knowledge graph should be coherent and easily understandable, "
"so maintaining consistency in entity references is crucial.\n"
"## 4. Strict Compliance\n"
"Adhere to the rules strictly. Non-compliance will result in termination."
)
default_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
system_prompt,
),
(
"human",
(
"Note: The information given to you is information about a User's Web Browsing History."
"Tip: Make sure to answer in the correct format and do "
"not include any explanations. "
"Use the given format to extract information from the "
"following input: {input}"
),
),
]
)
def _get_additional_info(input_type: str) -> str:
# Check if the input_type is one of the allowed values
if input_type not in ["node", "relationship", "property"]:
raise ValueError("input_type must be 'node', 'relationship', or 'property'")
# Perform actions based on the input_type
if input_type == "node":
return (
"Ensure you use basic or elementary types for node labels.\n"
"For example, when you identify an entity representing a person, "
"always label it as **'Person'**. Avoid using more specific terms "
"like 'Mathematician' or 'Scientist'"
)
elif input_type == "relationship":
return (
"Instead of using specific and momentary types such as "
"'BECAME_PROFESSOR', use more general and timeless relationship types "
"like 'PROFESSOR'. However, do not sacrifice any accuracy for generality"
)
elif input_type == "property":
return ""
return ""
def optional_enum_field(
enum_values: Optional[List[str]] = None,
description: str = "",
input_type: str = "node",
llm_type: Optional[str] = None,
**field_kwargs: Any,
) -> Any:
"""Utility function to conditionally create a field with an enum constraint."""
# Only openai supports enum param
if enum_values and llm_type == "openai-chat":
return Field(
...,
enum=enum_values,
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
elif enum_values:
return Field(
...,
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
else:
additional_info = _get_additional_info(input_type)
return Field(..., description=description + additional_info, **field_kwargs)
class _Graph(BaseModel):
nodes: Optional[List]
relationships: Optional[List]
class UnstructuredRelation(BaseModel):
head: str = Field(
description=(
"extracted head entity like Microsoft, Apple, John. "
"Must use human-readable unique identifier."
)
)
head_type: str = Field(
description="type of the extracted head entity like Person, Company, etc"
)
relation: str = Field(description="relation between the head and the tail entities")
tail: str = Field(
description=(
"extracted tail entity like Microsoft, Apple, John. "
"Must use human-readable unique identifier."
)
)
tail_type: str = Field(
description="type of the extracted tail entity like Person, Company, etc"
)
def create_unstructured_prompt(
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
) -> ChatPromptTemplate:
node_labels_str = str(node_labels) if node_labels else ""
rel_types_str = str(rel_types) if rel_types else ""
base_string_parts = [
"You are a top-tier algorithm designed for extracting information in "
"structured formats to build a knowledge graph. Your task is to identify "
"the entities and relations requested with the user prompt from a given "
"text. You must generate the output in a JSON format containing a list "
'with JSON objects. Each object should have the keys: "head", '
'"head_type", "relation", "tail", and "tail_type". The "head" '
"key must contain the text of the extracted entity with one of the types "
"from the provided list in the user prompt.",
f'The "head_type" key must contain the type of the extracted head entity, '
f"which must be one of the types from {node_labels_str}."
if node_labels
else "",
f'The "relation" key must contain the type of relation between the "head" '
f'and the "tail", which must be one of the relations from {rel_types_str}.'
if rel_types
else "",
f'The "tail" key must represent the text of an extracted entity which is '
f'the tail of the relation, and the "tail_type" key must contain the type '
f"of the tail entity from {node_labels_str}."
if node_labels
else "",
"Attempt to extract as many entities and relations as you can. Maintain "
"Entity Consistency: When extracting entities, it's vital to ensure "
'consistency. If an entity, such as "John Doe", is mentioned multiple '
"times in the text but is referred to by different names or pronouns "
'(e.g., "Joe", "he"), always use the most complete identifier for '
"that entity. The knowledge graph should be coherent and easily "
"understandable, so maintaining consistency in entity references is "
"crucial.",
"IMPORTANT NOTES:\n- Don't add any explanation and text.",
]
system_prompt = "\n".join(filter(None, base_string_parts))
system_message = SystemMessage(content=system_prompt)
parser = JsonOutputParser(pydantic_object=UnstructuredRelation)
human_string_parts = [
"Based on the following example, extract entities and "
"relations from the provided text.\n\n",
"Use the following entity types, don't use other entity "
"that is not defined below:"
"# ENTITY TYPES:"
"{node_labels}"
if node_labels
else "",
"Use the following relation types, don't use other relation "
"that is not defined below:"
"# RELATION TYPES:"
"{rel_types}"
if rel_types
else "",
"Below are a number of examples of text and their extracted "
"entities and relationships."
"{examples}\n"
"For the following text, extract entities and relations as "
"in the provided example."
"{format_instructions}\nText: {input}",
]
human_prompt_string = "\n".join(filter(None, human_string_parts))
human_prompt = PromptTemplate(
template=human_prompt_string,
input_variables=["input"],
partial_variables={
"format_instructions": parser.get_format_instructions(),
"node_labels": node_labels,
"rel_types": rel_types,
"examples": examples,
},
)
human_message_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message, human_message_prompt]
)
return chat_prompt
def create_simple_model(
node_labels: Optional[List[str]] = None,
rel_types: Optional[List[str]] = None,
node_properties: Union[bool, List[str]] = False,
llm_type: Optional[str] = None,
relationship_properties: Union[bool, List[str]] = False,
) -> Type[_Graph]:
"""
Create a simple graph model with optional constraints on node
and relationship types.
Args:
node_labels (Optional[List[str]]): Specifies the allowed node types.
Defaults to None, allowing all node types.
rel_types (Optional[List[str]]): Specifies the allowed relationship types.
Defaults to None, allowing all relationship types.
node_properties (Union[bool, List[str]]): Specifies if node properties should
be included. If a list is provided, only properties with keys in the list
will be included. If True, all properties are included. Defaults to False.
relationship_properties (Union[bool, List[str]]): Specifies if relationship
properties should be included. If a list is provided, only properties with
keys in the list will be included. If True, all properties are included.
Defaults to False.
llm_type (Optional[str]): The type of the language model. Defaults to None.
Only openai supports enum param: openai-chat.
Returns:
Type[_Graph]: A graph model with the specified constraints.
Raises:
ValueError: If 'id' is included in the node or relationship properties list.
"""
node_fields: Dict[str, Tuple[Any, Any]] = {
"id": (
str,
Field(..., description="Name or human-readable unique identifier."),
),
"type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the node.",
input_type="node",
llm_type=llm_type,
),
),
}
if node_properties:
if isinstance(node_properties, list) and "id" in node_properties:
raise ValueError("The node property 'id' is reserved and cannot be used.")
# Map True to empty array
node_properties_mapped: List[str] = (
[] if node_properties is True else node_properties
)
class Property(BaseModel):
"""A single property consisting of key and value"""
key: str = optional_enum_field(
node_properties_mapped,
description="Property key.",
input_type="property",
llm_type=llm_type,
)
value: str = Field(..., description="value")
node_fields["properties"] = (
Optional[List[Property]],
Field(None, description="List of node properties"),
)
SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore
relationship_fields: Dict[str, Tuple[Any, Any]] = {
"source_node_id": (
str,
Field(
...,
description="Name or human-readable unique identifier of source node",
),
),
"source_node_type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the source node.",
input_type="node",
llm_type=llm_type,
),
),
"target_node_id": (
str,
Field(
...,
description="Name or human-readable unique identifier of target node",
),
),
"target_node_type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the target node.",
input_type="node",
llm_type=llm_type,
),
),
"type": (
str,
optional_enum_field(
rel_types,
description="The type of the relationship.",
input_type="relationship",
llm_type=llm_type,
),
),
}
if relationship_properties:
if (
isinstance(relationship_properties, list)
and "id" in relationship_properties
):
raise ValueError(
"The relationship property 'id' is reserved and cannot be used."
)
# Map True to empty array
relationship_properties_mapped: List[str] = (
[] if relationship_properties is True else relationship_properties
)
class RelationshipProperty(BaseModel):
"""A single property consisting of key and value"""
key: str = optional_enum_field(
relationship_properties_mapped,
description="Property key.",
input_type="property",
llm_type=llm_type,
)
value: str = Field(..., description="value")
relationship_fields["properties"] = (
Optional[List[RelationshipProperty]],
Field(None, description="List of relationship properties"),
)
SimpleRelationship = create_model("SimpleRelationship", **relationship_fields) # type: ignore
class DynamicGraph(_Graph):
"""Represents a graph document consisting of nodes and relationships."""
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore
relationships: Optional[List[SimpleRelationship]] = Field( # type: ignore
description="List of relationships"
)
return DynamicGraph
def map_to_base_node(node: Any) -> Node:
"""Map the SimpleNode to the base Node."""
properties = {}
if hasattr(node, "properties") and node.properties:
for p in node.properties:
properties[format_property_key(p.key)] = p.value
return Node(id=node.id, type=node.type, properties=properties)
def map_to_base_relationship(rel: Any) -> Relationship:
"""Map the SimpleRelationship to the base Relationship."""
source = Node(id=rel.source_node_id, type=rel.source_node_type)
target = Node(id=rel.target_node_id, type=rel.target_node_type)
properties = {}
if hasattr(rel, "properties") and rel.properties:
for p in rel.properties:
properties[format_property_key(p.key)] = p.value
return Relationship(
source=source, target=target, type=rel.type, properties=properties
)
def _parse_and_clean_json(
argument_json: Dict[str, Any],
) -> Tuple[List[Node], List[Relationship]]:
nodes = []
for node in argument_json["nodes"]:
if not node.get("id"): # Id is mandatory, skip this node
continue
node_properties = {}
if "properties" in node and node["properties"]:
for p in node["properties"]:
node_properties[format_property_key(p["key"])] = p["value"]
nodes.append(
Node(
id=node["id"],
type=node.get("type"),
properties=node_properties,
)
)
relationships = []
for rel in argument_json["relationships"]:
# Mandatory props
if (
not rel.get("source_node_id")
or not rel.get("target_node_id")
or not rel.get("type")
):
continue
# Node type copying if needed from node list
if not rel.get("source_node_type"):
try:
rel["source_node_type"] = [
el.get("type")
for el in argument_json["nodes"]
if el["id"] == rel["source_node_id"]
][0]
except IndexError:
rel["source_node_type"] = None
if not rel.get("target_node_type"):
try:
rel["target_node_type"] = [
el.get("type")
for el in argument_json["nodes"]
if el["id"] == rel["target_node_id"]
][0]
except IndexError:
rel["target_node_type"] = None
rel_properties = {}
if "properties" in rel and rel["properties"]:
for p in rel["properties"]:
rel_properties[format_property_key(p["key"])] = p["value"]
source_node = Node(
id=rel["source_node_id"],
type=rel["source_node_type"],
)
target_node = Node(
id=rel["target_node_id"],
type=rel["target_node_type"],
)
relationships.append(
Relationship(
source=source_node,
target=target_node,
type=rel["type"],
properties=rel_properties,
)
)
return nodes, relationships
def _format_nodes(nodes: List[Node]) -> List[Node]:
return [
Node(
id=el.id.title() if isinstance(el.id, str) else el.id,
type=el.type.capitalize() # type: ignore[arg-type]
if el.type
else None, # handle empty strings # type: ignore[arg-type]
properties=el.properties,
)
for el in nodes
]
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
return [
Relationship(
source=_format_nodes([el.source])[0],
target=_format_nodes([el.target])[0],
type=el.type.replace(" ", "_").upper(),
properties=el.properties,
)
for el in rels
]
def format_property_key(s: str) -> str:
words = s.split()
if not words:
return s
first_word = words[0].lower()
capitalized_words = [word.capitalize() for word in words[1:]]
return "".join([first_word] + capitalized_words)
def _convert_to_graph_document(
raw_schema: Dict[Any, Any],
) -> Tuple[List[Node], List[Relationship]]:
# If there are validation errors
if not raw_schema["parsed"]:
try:
try: # OpenAI type response
argument_json = json.loads(
raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][
"arguments"
]
)
except Exception: # Google type response
argument_json = json.loads(
raw_schema["raw"].additional_kwargs["function_call"]["arguments"]
)
nodes, relationships = _parse_and_clean_json(argument_json)
except Exception: # If we can't parse JSON
return ([], [])
else: # If there are no validation errors use parsed pydantic object
parsed_schema: _Graph = raw_schema["parsed"]
nodes = (
[map_to_base_node(node) for node in parsed_schema.nodes if node.id]
if parsed_schema.nodes
else []
)
relationships = (
[
map_to_base_relationship(rel)
for rel in parsed_schema.relationships
if rel.type and rel.source_node_id and rel.target_node_id
]
if parsed_schema.relationships
else []
)
# Title / Capitalize
return _format_nodes(nodes), _format_relationships(relationships)
class LLMGraphTransformer:
"""Transform documents into graph-based documents using a LLM.
It allows specifying constraints on the types of nodes and relationships to include
in the output graph. The class supports extracting properties for both nodes and
relationships.
Args:
llm (BaseLanguageModel): An instance of a language model supporting structured
output.
allowed_nodes (List[str], optional): Specifies which node types are
allowed in the graph. Defaults to an empty list, allowing all node types.
allowed_relationships (List[str], optional): Specifies which relationship types
are allowed in the graph. Defaults to an empty list, allowing all relationship
types.
prompt (Optional[ChatPromptTemplate], optional): The prompt to pass to
the LLM with additional instructions.
strict_mode (bool, optional): Determines whether the transformer should apply
filtering to strictly adhere to `allowed_nodes` and `allowed_relationships`.
Defaults to True.
node_properties (Union[bool, List[str]]): If True, the LLM can extract any
node properties from text. Alternatively, a list of valid properties can
be provided for the LLM to extract, restricting extraction to those specified.
relationship_properties (Union[bool, List[str]]): If True, the LLM can extract
any relationship properties from text. Alternatively, a list of valid
properties can be provided for the LLM to extract, restricting extraction to
those specified.
Example:
.. code-block:: python
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
llm=ChatOpenAI(temperature=0)
transformer = LLMGraphTransformer(
llm=llm,
allowed_nodes=["Person", "Organization"])
doc = Document(page_content="Elon Musk is suing OpenAI")
graph_documents = transformer.convert_to_graph_documents([doc])
"""
def __init__(
self,
llm: BaseLanguageModel,
allowed_nodes: List[str] = [],
allowed_relationships: List[str] = [],
prompt: Optional[ChatPromptTemplate] = None,
strict_mode: bool = True,
node_properties: Union[bool, List[str]] = False,
relationship_properties: Union[bool, List[str]] = False,
) -> None:
self.allowed_nodes = allowed_nodes
self.allowed_relationships = allowed_relationships
self.strict_mode = strict_mode
self._function_call = True
# Check if the LLM really supports structured output
try:
llm.with_structured_output(_Graph)
except NotImplementedError:
self._function_call = False
if not self._function_call:
if node_properties or relationship_properties:
raise ValueError(
"The 'node_properties' and 'relationship_properties' parameters "
"cannot be used in combination with a LLM that doesn't support "
"native function calling."
)
try:
import json_repair # type: ignore
self.json_repair = json_repair
except ImportError:
raise ImportError(
"Could not import json_repair python package. "
"Please install it with `pip install json-repair`."
)
prompt = prompt or create_unstructured_prompt(
allowed_nodes, allowed_relationships
)
self.chain = prompt | llm
else:
# Define chain
try:
llm_type = llm._llm_type # type: ignore
except AttributeError:
llm_type = None
schema = create_simple_model(
allowed_nodes,
allowed_relationships,
node_properties,
llm_type,
relationship_properties,
)
structured_llm = llm.with_structured_output(schema, include_raw=True)
prompt = prompt or default_prompt
self.chain = prompt | structured_llm
def process_response(
self, document: Document, config: Optional[RunnableConfig] = None
) -> GraphDocument:
"""
Processes a single document, transforming it into a graph document using
an LLM based on the model's schema and constraints.
"""
text = document.page_content
raw_schema = self.chain.invoke({"input": text}, config=config)
if self._function_call:
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
else:
nodes_set = set()
relationships = []
if not isinstance(raw_schema, str):
raw_schema = raw_schema.content
parsed_json = self.json_repair.loads(raw_schema)
for rel in parsed_json:
# Nodes need to be deduplicated using a set
nodes_set.add((rel["head"], rel["head_type"]))
nodes_set.add((rel["tail"], rel["tail_type"]))
source_node = Node(id=rel["head"], type=rel["head_type"])
target_node = Node(id=rel["tail"], type=rel["tail_type"])
relationships.append(
Relationship(
source=source_node, target=target_node, type=rel["relation"]
)
)
# Create nodes list
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
# Strict mode filtering
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes:
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
nodes = [
node for node in nodes if node.type.lower() in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel.source.type.lower() in lower_allowed_nodes
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
def convert_to_graph_documents(
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
) -> List[GraphDocument]:
"""Convert a sequence of documents into graph documents.
Args:
documents (Sequence[Document]): The original documents.
**kwargs: Additional keyword arguments.
Returns:
Sequence[GraphDocument]: The transformed documents as graphs.
"""
return [self.process_response(document, config) for document in documents]
async def aprocess_response(
self, document: Document, config: Optional[RunnableConfig] = None
) -> GraphDocument:
"""
Asynchronously processes a single document, transforming it into a
graph document.
"""
text = document.page_content
raw_schema = await self.chain.ainvoke({"input": text}, config=config)
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes:
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
nodes = [
node for node in nodes if node.type.lower() in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel.source.type.lower() in lower_allowed_nodes
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
async def aconvert_to_graph_documents(
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
) -> List[GraphDocument]:
"""
Asynchronously convert a sequence of documents into graph documents.
"""
tasks = [
asyncio.create_task(self.aprocess_response(document, config))
for document in documents
]
results = await asyncio.gather(*tasks)
return results

14
backend/database.py Normal file
View file

@ -0,0 +1,14 @@
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from envs import POSTGRES_DATABASE_URL
engine = create_engine(
POSTGRES_DATABASE_URL
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

10
backend/envs.py Normal file
View file

@ -0,0 +1,10 @@
#POSTGRES DB TO TRACK USERS
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
# API KEY TO VERIFY
API_SECRET_KEY = "surfsense"
# Your JWT secret and algorithm
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 720

14
backend/models.py Normal file
View file

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, String
from database import Base
from database import engine
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)
hashed_password = Column(String)
# Create the database tables if they don't exist
User.metadata.create_all(bind=engine)

84
backend/prompts.py Normal file
View file

@ -0,0 +1,84 @@
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from datetime import datetime, timezone
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
CYPHER_QA_TEMPLATE = DATE_TODAY + """You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here are the examples:
Question: Website on which the most time was spend on?
Context:[{'d.VisitedWebPageURL': 'https://stackoverflow.com/questions/59873698/the-default-export-is-not-a-react-component-in-page-nextjs', 'totalDuration': 8889167}]
Helpful Answer: You visited https://stackoverflow.com/questions/59873698/the-default-export-is-not-a-react-component-in-page-nextjs for 8889167 milliseconds or 8889.167 seconds.
Follow this example when generating answers.
If the provided information is empty, then and only then, return exactly 'don't know' as answer.
Information:
{context}
Question: {question}
Helpful Answer:"""
CYPHER_QA_PROMPT = PromptTemplate(
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
)
SIMILARITY_SEARCH_RAG = DATE_TODAY + """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, return exactly 'don't know' as answer.
Question: {question}
Context: {context}
Answer:"""
SIMILARITY_SEARCH_PROMPT = PromptTemplate(
input_variables=["context", "question"], template=SIMILARITY_SEARCH_RAG
)
# doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
CYPHER_GENERATION_TEMPLATE = DATE_TODAY + """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)
DOC_DESCRIPTION_TEMPLATE = """Task:Give Detailed Description of the page content of the given document.
Instructions:
Provide as much details about metadata & page content as if you need to give human readable report of this Browsing session event.
Document:
{document}
"""
DOC_DESCRIPTION_PROMPT = PromptTemplate(
input_variables=["document"], template=DOC_DESCRIPTION_TEMPLATE
)
DOCUMENT_METADATA_EXTRACTION_SYSTEM_MESSAGE = DATE_TODAY + """You are a helpful assistant. You are given a Cypher statement result after quering the Neo4j graph database.
Generate a very good Query that can be used to perform similarity search on the vector store of the Neo4j graph database"""
DOCUMENT_METADATA_EXTRACTION_PROMT = ChatPromptTemplate.from_messages([("system", DOCUMENT_METADATA_EXTRACTION_SYSTEM_MESSAGE), ("human", "{input}")])

41
backend/pydmodels.py Normal file
View file

@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class UserQuery(BaseModel):
query: str
neourl: str
neouser: str
neopass: str
openaikey: str
apisecretkey: str
class DescriptionResponse(BaseModel):
response: str
class DocMeta(BaseModel):
BrowsingSessionId: Optional[str] = Field(default=None, description="BrowsingSessionId of Document")
VisitedWebPageURL: Optional[str] = Field(default=None, description="VisitedWebPageURL of Document")
VisitedWebPageTitle: Optional[str] = Field(default=None, description="VisitedWebPageTitle of Document")
VisitedWebPageDateWithTimeInISOString: Optional[str] = Field(default=None, description="VisitedWebPageDateWithTimeInISOString of Document")
VisitedWebPageReffererURL: Optional[str] = Field(default=None, description="VisitedWebPageReffererURL of Document")
VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document"),
VisitedWebPageContent: Optional[str] = Field(default=None, description="Visited WebPage Content in markdown of Document")
class RetrivedDocListItem(BaseModel):
metadata: DocMeta
pageContent: str
class RetrivedDocList(BaseModel):
documents: List[RetrivedDocListItem]
neourl: str
neouser: str
neopass: str
openaikey: str
apisecretkey: str
class UserQueryResponse(BaseModel):
response: str
relateddocs: List[DocMeta]

8
backend/requirements.txt Normal file
View file

@ -0,0 +1,8 @@
bcrypt
cryptography
fastapi
python-jose
python-multipart
SQLAlchemy
uvicorn
passlib

322
backend/server.py Normal file
View file

@ -0,0 +1,322 @@
from __future__ import annotations
from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse
from langchain_experimental.text_splitter import SemanticChunker
#Our Imps
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import User
from database import SessionLocal, engine
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
@app.post("/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
query = data.query
graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)
embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=data.openaikey,
)
chain = GraphCypherQAChain.from_llm(
graph=graph,
cypher_prompt=CYPHER_GENERATION_PROMPT,
cypher_llm=llm,
verbose=True,
validate_cypher=True,
qa_prompt=CYPHER_QA_PROMPT ,
qa_llm=llm,
return_intermediate_steps=True,
top_k=5,
)
vector_index = Neo4jVector.from_existing_graph(
embeddings,
graph=graph,
search_type="hybrid",
node_label="Document",
text_node_properties=["text"],
embedding_node_property="embedding",
)
docs = vector_index.similarity_search(query,k=5)
docstoreturn = []
for doc in docs:
docstoreturn.append(
DocMeta(
BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE",
VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None,
VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE"
)
)
docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
try:
response = chain.invoke({"query": query})
if "don't know" in response["result"]:
raise Exception("No response from graph")
structured_llm = llm.with_structured_output(RetrivedDocList)
doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
query = doc_extract_chain.invoke(response["intermediate_steps"][1]["context"])
docs = vector_index.similarity_search(query.searchquery,k=5)
docstoreturn = []
for doc in docs:
docstoreturn.append(
DocMeta(
BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE",
VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE",
VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None,
VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE"
)
)
docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
return UserQueryResponse(relateddocs=docstoreturn,response=response["result"])
except:
# Fallback to Similarity Search RAG
searchchain = SIMILARITY_SEARCH_PROMPT | llm
response = searchchain.invoke({"question": query, "context": docs})
return UserQueryResponse(relateddocs=docstoreturn,response=response.content)
# DOC DESCRIPTION
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
document = data.query
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)
descriptionchain = DOC_DESCRIPTION_PROMPT | llm
response = descriptionchain.invoke({"document": document})
return DescriptionResponse(response=response.content)
# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList):
if(apires.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
print("STARTED")
# print(apires)
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=apires.openaikey
)
embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=apires.openaikey,
)
llm_transformer = LLMGraphTransformer(llm=llm)
raw_documents = []
for doc in apires.documents:
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
text_splitter = SemanticChunker(embeddings=embeddings)
documents = text_splitter.split_documents(raw_documents)
graph_documents = llm_transformer.convert_to_graph_documents(documents)
graph.add_graph_documents(
graph_documents,
baseEntityLabel=True,
include_source=True
)
print("FINISHED")
return {
"success": "Graph Will be populated Shortly"
}
#AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Recommended for Local Setups
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
# ]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins from the list
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str
def get_user_by_username(db: Session, username: str):
return db.query(User).filter(User.username == username).first()
def create_user(db: Session, user: UserCreate):
hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password)
db.add(db_user)
db.commit()
return "complete"
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
if(user.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
db_user = get_user_by_username(db, username=user.username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
del user.apisecretkey
return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session):
user = db.query(User).filter(User.username == username).first()
if not user:
return False
if not pwd_context.verify(password, user.hashed_password):
return False
return user
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
user = authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
return payload
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
verify_token(token=token)
return {"message": "Token is valid"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)