mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-16 21:05:20 +02:00
init
This commit is contained in:
parent
24f641347d
commit
63146aa9b7
86 changed files with 18766 additions and 0 deletions
12
backend/.gitignore
vendored
Normal file
12
backend/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
__pycache__
|
||||
__pycache__/
|
||||
.__pycache__
|
||||
830
backend/LLMGraphTransformer.py
Normal file
830
backend/LLMGraphTransformer.py
Normal file
|
|
@ -0,0 +1,830 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
PromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
examples = [
|
||||
{
|
||||
"text": (
|
||||
"Adam is a software engineer in Microsoft since 2009, "
|
||||
"and last year he got an award as the Best Talent"
|
||||
),
|
||||
"head": "Adam",
|
||||
"head_type": "Person",
|
||||
"relation": "WORKS_FOR",
|
||||
"tail": "Microsoft",
|
||||
"tail_type": "Company",
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
"Adam is a software engineer in Microsoft since 2009, "
|
||||
"and last year he got an award as the Best Talent"
|
||||
),
|
||||
"head": "Adam",
|
||||
"head_type": "Person",
|
||||
"relation": "HAS_AWARD",
|
||||
"tail": "Best Talent",
|
||||
"tail_type": "Award",
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
"Microsoft is a tech company that provide "
|
||||
"several products such as Microsoft Word"
|
||||
),
|
||||
"head": "Microsoft Word",
|
||||
"head_type": "Product",
|
||||
"relation": "PRODUCED_BY",
|
||||
"tail": "Microsoft",
|
||||
"tail_type": "Company",
|
||||
},
|
||||
{
|
||||
"text": "Microsoft Word is a lightweight app that accessible offline",
|
||||
"head": "Microsoft Word",
|
||||
"head_type": "Product",
|
||||
"relation": "HAS_CHARACTERISTIC",
|
||||
"tail": "lightweight app",
|
||||
"tail_type": "Characteristic",
|
||||
},
|
||||
{
|
||||
"text": "Microsoft Word is a lightweight app that accessible offline",
|
||||
"head": "Microsoft Word",
|
||||
"head_type": "Product",
|
||||
"relation": "HAS_CHARACTERISTIC",
|
||||
"tail": "accessible offline",
|
||||
"tail_type": "Characteristic",
|
||||
},
|
||||
]
|
||||
|
||||
system_prompt = (
|
||||
"# Knowledge Graph Instructions for GPT-4\n"
|
||||
"## 1. Overview\n"
|
||||
"You are a top-tier algorithm designed for extracting information in structured "
|
||||
"formats to build a knowledge graph.\n"
|
||||
"Try to capture as much information from the text as possible without "
|
||||
"sacrificing accuracy. Do not add any information that is not explicitly "
|
||||
"mentioned in the text.\n"
|
||||
"- **Nodes** represent entities and concepts.\n"
|
||||
"- The aim is to achieve simplicity and clarity in the knowledge graph, making it\n"
|
||||
"accessible for a vast audience.\n"
|
||||
"## 2. Labeling Nodes\n"
|
||||
"- **Consistency**: Ensure you use available types for node labels.\n"
|
||||
"Ensure you use basic or elementary types for node labels.\n"
|
||||
"- For example, when you identify an entity representing a person, "
|
||||
"always label it as **'person'**. Avoid using more specific terms "
|
||||
"like 'mathematician' or 'scientist'."
|
||||
"- **Node IDs**: Never utilize integers as node IDs. Node IDs should be "
|
||||
"names or human-readable identifiers found in the text.\n"
|
||||
"- **Relationships** represent connections between entities or concepts.\n"
|
||||
"Ensure consistency and generality in relationship types when constructing "
|
||||
"knowledge graphs. Instead of using specific and momentary types "
|
||||
"such as 'BECAME_PROFESSOR', use more general and timeless relationship types "
|
||||
"like 'PROFESSOR'. Make sure to use general and timeless relationship types!\n"
|
||||
"## 3. Coreference Resolution\n"
|
||||
"- **Maintain Entity Consistency**: When extracting entities, it's vital to "
|
||||
"ensure consistency.\n"
|
||||
'If an entity, such as "John Doe", is mentioned multiple times in the text '
|
||||
'but is referred to by different names or pronouns (e.g., "Joe", "he"),'
|
||||
"always use the most complete identifier for that entity throughout the "
|
||||
'knowledge graph. In this example, use "John Doe" as the entity ID.\n'
|
||||
"Remember, the knowledge graph should be coherent and easily understandable, "
|
||||
"so maintaining consistency in entity references is crucial.\n"
|
||||
"## 4. Strict Compliance\n"
|
||||
"Adhere to the rules strictly. Non-compliance will result in termination."
|
||||
)
|
||||
|
||||
default_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
system_prompt,
|
||||
),
|
||||
(
|
||||
"human",
|
||||
(
|
||||
"Note: The information given to you is information about a User's Web Browsing History."
|
||||
"Tip: Make sure to answer in the correct format and do "
|
||||
"not include any explanations. "
|
||||
"Use the given format to extract information from the "
|
||||
"following input: {input}"
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_additional_info(input_type: str) -> str:
|
||||
# Check if the input_type is one of the allowed values
|
||||
if input_type not in ["node", "relationship", "property"]:
|
||||
raise ValueError("input_type must be 'node', 'relationship', or 'property'")
|
||||
|
||||
# Perform actions based on the input_type
|
||||
if input_type == "node":
|
||||
return (
|
||||
"Ensure you use basic or elementary types for node labels.\n"
|
||||
"For example, when you identify an entity representing a person, "
|
||||
"always label it as **'Person'**. Avoid using more specific terms "
|
||||
"like 'Mathematician' or 'Scientist'"
|
||||
)
|
||||
elif input_type == "relationship":
|
||||
return (
|
||||
"Instead of using specific and momentary types such as "
|
||||
"'BECAME_PROFESSOR', use more general and timeless relationship types "
|
||||
"like 'PROFESSOR'. However, do not sacrifice any accuracy for generality"
|
||||
)
|
||||
elif input_type == "property":
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def optional_enum_field(
|
||||
enum_values: Optional[List[str]] = None,
|
||||
description: str = "",
|
||||
input_type: str = "node",
|
||||
llm_type: Optional[str] = None,
|
||||
**field_kwargs: Any,
|
||||
) -> Any:
|
||||
"""Utility function to conditionally create a field with an enum constraint."""
|
||||
# Only openai supports enum param
|
||||
if enum_values and llm_type == "openai-chat":
|
||||
return Field(
|
||||
...,
|
||||
enum=enum_values,
|
||||
description=f"{description}. Available options are {enum_values}",
|
||||
**field_kwargs,
|
||||
)
|
||||
elif enum_values:
|
||||
return Field(
|
||||
...,
|
||||
description=f"{description}. Available options are {enum_values}",
|
||||
**field_kwargs,
|
||||
)
|
||||
else:
|
||||
additional_info = _get_additional_info(input_type)
|
||||
return Field(..., description=description + additional_info, **field_kwargs)
|
||||
|
||||
|
||||
class _Graph(BaseModel):
|
||||
nodes: Optional[List]
|
||||
relationships: Optional[List]
|
||||
|
||||
|
||||
class UnstructuredRelation(BaseModel):
|
||||
head: str = Field(
|
||||
description=(
|
||||
"extracted head entity like Microsoft, Apple, John. "
|
||||
"Must use human-readable unique identifier."
|
||||
)
|
||||
)
|
||||
head_type: str = Field(
|
||||
description="type of the extracted head entity like Person, Company, etc"
|
||||
)
|
||||
relation: str = Field(description="relation between the head and the tail entities")
|
||||
tail: str = Field(
|
||||
description=(
|
||||
"extracted tail entity like Microsoft, Apple, John. "
|
||||
"Must use human-readable unique identifier."
|
||||
)
|
||||
)
|
||||
tail_type: str = Field(
|
||||
description="type of the extracted tail entity like Person, Company, etc"
|
||||
)
|
||||
|
||||
|
||||
def create_unstructured_prompt(
|
||||
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
|
||||
) -> ChatPromptTemplate:
|
||||
node_labels_str = str(node_labels) if node_labels else ""
|
||||
rel_types_str = str(rel_types) if rel_types else ""
|
||||
base_string_parts = [
|
||||
"You are a top-tier algorithm designed for extracting information in "
|
||||
"structured formats to build a knowledge graph. Your task is to identify "
|
||||
"the entities and relations requested with the user prompt from a given "
|
||||
"text. You must generate the output in a JSON format containing a list "
|
||||
'with JSON objects. Each object should have the keys: "head", '
|
||||
'"head_type", "relation", "tail", and "tail_type". The "head" '
|
||||
"key must contain the text of the extracted entity with one of the types "
|
||||
"from the provided list in the user prompt.",
|
||||
f'The "head_type" key must contain the type of the extracted head entity, '
|
||||
f"which must be one of the types from {node_labels_str}."
|
||||
if node_labels
|
||||
else "",
|
||||
f'The "relation" key must contain the type of relation between the "head" '
|
||||
f'and the "tail", which must be one of the relations from {rel_types_str}.'
|
||||
if rel_types
|
||||
else "",
|
||||
f'The "tail" key must represent the text of an extracted entity which is '
|
||||
f'the tail of the relation, and the "tail_type" key must contain the type '
|
||||
f"of the tail entity from {node_labels_str}."
|
||||
if node_labels
|
||||
else "",
|
||||
"Attempt to extract as many entities and relations as you can. Maintain "
|
||||
"Entity Consistency: When extracting entities, it's vital to ensure "
|
||||
'consistency. If an entity, such as "John Doe", is mentioned multiple '
|
||||
"times in the text but is referred to by different names or pronouns "
|
||||
'(e.g., "Joe", "he"), always use the most complete identifier for '
|
||||
"that entity. The knowledge graph should be coherent and easily "
|
||||
"understandable, so maintaining consistency in entity references is "
|
||||
"crucial.",
|
||||
"IMPORTANT NOTES:\n- Don't add any explanation and text.",
|
||||
]
|
||||
system_prompt = "\n".join(filter(None, base_string_parts))
|
||||
|
||||
system_message = SystemMessage(content=system_prompt)
|
||||
parser = JsonOutputParser(pydantic_object=UnstructuredRelation)
|
||||
|
||||
human_string_parts = [
|
||||
"Based on the following example, extract entities and "
|
||||
"relations from the provided text.\n\n",
|
||||
"Use the following entity types, don't use other entity "
|
||||
"that is not defined below:"
|
||||
"# ENTITY TYPES:"
|
||||
"{node_labels}"
|
||||
if node_labels
|
||||
else "",
|
||||
"Use the following relation types, don't use other relation "
|
||||
"that is not defined below:"
|
||||
"# RELATION TYPES:"
|
||||
"{rel_types}"
|
||||
if rel_types
|
||||
else "",
|
||||
"Below are a number of examples of text and their extracted "
|
||||
"entities and relationships."
|
||||
"{examples}\n"
|
||||
"For the following text, extract entities and relations as "
|
||||
"in the provided example."
|
||||
"{format_instructions}\nText: {input}",
|
||||
]
|
||||
human_prompt_string = "\n".join(filter(None, human_string_parts))
|
||||
human_prompt = PromptTemplate(
|
||||
template=human_prompt_string,
|
||||
input_variables=["input"],
|
||||
partial_variables={
|
||||
"format_instructions": parser.get_format_instructions(),
|
||||
"node_labels": node_labels,
|
||||
"rel_types": rel_types,
|
||||
"examples": examples,
|
||||
},
|
||||
)
|
||||
|
||||
human_message_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[system_message, human_message_prompt]
|
||||
)
|
||||
return chat_prompt
|
||||
|
||||
|
||||
def create_simple_model(
|
||||
node_labels: Optional[List[str]] = None,
|
||||
rel_types: Optional[List[str]] = None,
|
||||
node_properties: Union[bool, List[str]] = False,
|
||||
llm_type: Optional[str] = None,
|
||||
relationship_properties: Union[bool, List[str]] = False,
|
||||
) -> Type[_Graph]:
|
||||
"""
|
||||
Create a simple graph model with optional constraints on node
|
||||
and relationship types.
|
||||
|
||||
Args:
|
||||
node_labels (Optional[List[str]]): Specifies the allowed node types.
|
||||
Defaults to None, allowing all node types.
|
||||
rel_types (Optional[List[str]]): Specifies the allowed relationship types.
|
||||
Defaults to None, allowing all relationship types.
|
||||
node_properties (Union[bool, List[str]]): Specifies if node properties should
|
||||
be included. If a list is provided, only properties with keys in the list
|
||||
will be included. If True, all properties are included. Defaults to False.
|
||||
relationship_properties (Union[bool, List[str]]): Specifies if relationship
|
||||
properties should be included. If a list is provided, only properties with
|
||||
keys in the list will be included. If True, all properties are included.
|
||||
Defaults to False.
|
||||
llm_type (Optional[str]): The type of the language model. Defaults to None.
|
||||
Only openai supports enum param: openai-chat.
|
||||
|
||||
Returns:
|
||||
Type[_Graph]: A graph model with the specified constraints.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'id' is included in the node or relationship properties list.
|
||||
"""
|
||||
|
||||
node_fields: Dict[str, Tuple[Any, Any]] = {
|
||||
"id": (
|
||||
str,
|
||||
Field(..., description="Name or human-readable unique identifier."),
|
||||
),
|
||||
"type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
node_labels,
|
||||
description="The type or label of the node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
if node_properties:
|
||||
if isinstance(node_properties, list) and "id" in node_properties:
|
||||
raise ValueError("The node property 'id' is reserved and cannot be used.")
|
||||
# Map True to empty array
|
||||
node_properties_mapped: List[str] = (
|
||||
[] if node_properties is True else node_properties
|
||||
)
|
||||
|
||||
class Property(BaseModel):
|
||||
"""A single property consisting of key and value"""
|
||||
|
||||
key: str = optional_enum_field(
|
||||
node_properties_mapped,
|
||||
description="Property key.",
|
||||
input_type="property",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
value: str = Field(..., description="value")
|
||||
|
||||
node_fields["properties"] = (
|
||||
Optional[List[Property]],
|
||||
Field(None, description="List of node properties"),
|
||||
)
|
||||
SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore
|
||||
|
||||
relationship_fields: Dict[str, Tuple[Any, Any]] = {
|
||||
"source_node_id": (
|
||||
str,
|
||||
Field(
|
||||
...,
|
||||
description="Name or human-readable unique identifier of source node",
|
||||
),
|
||||
),
|
||||
"source_node_type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
node_labels,
|
||||
description="The type or label of the source node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
"target_node_id": (
|
||||
str,
|
||||
Field(
|
||||
...,
|
||||
description="Name or human-readable unique identifier of target node",
|
||||
),
|
||||
),
|
||||
"target_node_type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
node_labels,
|
||||
description="The type or label of the target node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
"type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
rel_types,
|
||||
description="The type of the relationship.",
|
||||
input_type="relationship",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
}
|
||||
if relationship_properties:
|
||||
if (
|
||||
isinstance(relationship_properties, list)
|
||||
and "id" in relationship_properties
|
||||
):
|
||||
raise ValueError(
|
||||
"The relationship property 'id' is reserved and cannot be used."
|
||||
)
|
||||
# Map True to empty array
|
||||
relationship_properties_mapped: List[str] = (
|
||||
[] if relationship_properties is True else relationship_properties
|
||||
)
|
||||
|
||||
class RelationshipProperty(BaseModel):
|
||||
"""A single property consisting of key and value"""
|
||||
|
||||
key: str = optional_enum_field(
|
||||
relationship_properties_mapped,
|
||||
description="Property key.",
|
||||
input_type="property",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
value: str = Field(..., description="value")
|
||||
|
||||
relationship_fields["properties"] = (
|
||||
Optional[List[RelationshipProperty]],
|
||||
Field(None, description="List of relationship properties"),
|
||||
)
|
||||
SimpleRelationship = create_model("SimpleRelationship", **relationship_fields) # type: ignore
|
||||
|
||||
class DynamicGraph(_Graph):
|
||||
"""Represents a graph document consisting of nodes and relationships."""
|
||||
|
||||
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore
|
||||
relationships: Optional[List[SimpleRelationship]] = Field( # type: ignore
|
||||
description="List of relationships"
|
||||
)
|
||||
|
||||
return DynamicGraph
|
||||
|
||||
|
||||
def map_to_base_node(node: Any) -> Node:
|
||||
"""Map the SimpleNode to the base Node."""
|
||||
properties = {}
|
||||
if hasattr(node, "properties") and node.properties:
|
||||
for p in node.properties:
|
||||
properties[format_property_key(p.key)] = p.value
|
||||
return Node(id=node.id, type=node.type, properties=properties)
|
||||
|
||||
|
||||
def map_to_base_relationship(rel: Any) -> Relationship:
|
||||
"""Map the SimpleRelationship to the base Relationship."""
|
||||
source = Node(id=rel.source_node_id, type=rel.source_node_type)
|
||||
target = Node(id=rel.target_node_id, type=rel.target_node_type)
|
||||
properties = {}
|
||||
if hasattr(rel, "properties") and rel.properties:
|
||||
for p in rel.properties:
|
||||
properties[format_property_key(p.key)] = p.value
|
||||
return Relationship(
|
||||
source=source, target=target, type=rel.type, properties=properties
|
||||
)
|
||||
|
||||
|
||||
def _parse_and_clean_json(
|
||||
argument_json: Dict[str, Any],
|
||||
) -> Tuple[List[Node], List[Relationship]]:
|
||||
nodes = []
|
||||
for node in argument_json["nodes"]:
|
||||
if not node.get("id"): # Id is mandatory, skip this node
|
||||
continue
|
||||
node_properties = {}
|
||||
if "properties" in node and node["properties"]:
|
||||
for p in node["properties"]:
|
||||
node_properties[format_property_key(p["key"])] = p["value"]
|
||||
nodes.append(
|
||||
Node(
|
||||
id=node["id"],
|
||||
type=node.get("type"),
|
||||
properties=node_properties,
|
||||
)
|
||||
)
|
||||
relationships = []
|
||||
for rel in argument_json["relationships"]:
|
||||
# Mandatory props
|
||||
if (
|
||||
not rel.get("source_node_id")
|
||||
or not rel.get("target_node_id")
|
||||
or not rel.get("type")
|
||||
):
|
||||
continue
|
||||
|
||||
# Node type copying if needed from node list
|
||||
if not rel.get("source_node_type"):
|
||||
try:
|
||||
rel["source_node_type"] = [
|
||||
el.get("type")
|
||||
for el in argument_json["nodes"]
|
||||
if el["id"] == rel["source_node_id"]
|
||||
][0]
|
||||
except IndexError:
|
||||
rel["source_node_type"] = None
|
||||
if not rel.get("target_node_type"):
|
||||
try:
|
||||
rel["target_node_type"] = [
|
||||
el.get("type")
|
||||
for el in argument_json["nodes"]
|
||||
if el["id"] == rel["target_node_id"]
|
||||
][0]
|
||||
except IndexError:
|
||||
rel["target_node_type"] = None
|
||||
|
||||
rel_properties = {}
|
||||
if "properties" in rel and rel["properties"]:
|
||||
for p in rel["properties"]:
|
||||
rel_properties[format_property_key(p["key"])] = p["value"]
|
||||
|
||||
source_node = Node(
|
||||
id=rel["source_node_id"],
|
||||
type=rel["source_node_type"],
|
||||
)
|
||||
target_node = Node(
|
||||
id=rel["target_node_id"],
|
||||
type=rel["target_node_type"],
|
||||
)
|
||||
relationships.append(
|
||||
Relationship(
|
||||
source=source_node,
|
||||
target=target_node,
|
||||
type=rel["type"],
|
||||
properties=rel_properties,
|
||||
)
|
||||
)
|
||||
return nodes, relationships
|
||||
|
||||
|
||||
def _format_nodes(nodes: List[Node]) -> List[Node]:
|
||||
return [
|
||||
Node(
|
||||
id=el.id.title() if isinstance(el.id, str) else el.id,
|
||||
type=el.type.capitalize() # type: ignore[arg-type]
|
||||
if el.type
|
||||
else None, # handle empty strings # type: ignore[arg-type]
|
||||
properties=el.properties,
|
||||
)
|
||||
for el in nodes
|
||||
]
|
||||
|
||||
|
||||
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
|
||||
return [
|
||||
Relationship(
|
||||
source=_format_nodes([el.source])[0],
|
||||
target=_format_nodes([el.target])[0],
|
||||
type=el.type.replace(" ", "_").upper(),
|
||||
properties=el.properties,
|
||||
)
|
||||
for el in rels
|
||||
]
|
||||
|
||||
|
||||
def format_property_key(s: str) -> str:
|
||||
words = s.split()
|
||||
if not words:
|
||||
return s
|
||||
first_word = words[0].lower()
|
||||
capitalized_words = [word.capitalize() for word in words[1:]]
|
||||
return "".join([first_word] + capitalized_words)
|
||||
|
||||
|
||||
def _convert_to_graph_document(
|
||||
raw_schema: Dict[Any, Any],
|
||||
) -> Tuple[List[Node], List[Relationship]]:
|
||||
# If there are validation errors
|
||||
if not raw_schema["parsed"]:
|
||||
try:
|
||||
try: # OpenAI type response
|
||||
argument_json = json.loads(
|
||||
raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][
|
||||
"arguments"
|
||||
]
|
||||
)
|
||||
except Exception: # Google type response
|
||||
argument_json = json.loads(
|
||||
raw_schema["raw"].additional_kwargs["function_call"]["arguments"]
|
||||
)
|
||||
|
||||
nodes, relationships = _parse_and_clean_json(argument_json)
|
||||
except Exception: # If we can't parse JSON
|
||||
return ([], [])
|
||||
else: # If there are no validation errors use parsed pydantic object
|
||||
parsed_schema: _Graph = raw_schema["parsed"]
|
||||
nodes = (
|
||||
[map_to_base_node(node) for node in parsed_schema.nodes if node.id]
|
||||
if parsed_schema.nodes
|
||||
else []
|
||||
)
|
||||
|
||||
relationships = (
|
||||
[
|
||||
map_to_base_relationship(rel)
|
||||
for rel in parsed_schema.relationships
|
||||
if rel.type and rel.source_node_id and rel.target_node_id
|
||||
]
|
||||
if parsed_schema.relationships
|
||||
else []
|
||||
)
|
||||
# Title / Capitalize
|
||||
return _format_nodes(nodes), _format_relationships(relationships)
|
||||
|
||||
|
||||
class LLMGraphTransformer:
|
||||
"""Transform documents into graph-based documents using a LLM.
|
||||
|
||||
It allows specifying constraints on the types of nodes and relationships to include
|
||||
in the output graph. The class supports extracting properties for both nodes and
|
||||
relationships.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): An instance of a language model supporting structured
|
||||
output.
|
||||
allowed_nodes (List[str], optional): Specifies which node types are
|
||||
allowed in the graph. Defaults to an empty list, allowing all node types.
|
||||
allowed_relationships (List[str], optional): Specifies which relationship types
|
||||
are allowed in the graph. Defaults to an empty list, allowing all relationship
|
||||
types.
|
||||
prompt (Optional[ChatPromptTemplate], optional): The prompt to pass to
|
||||
the LLM with additional instructions.
|
||||
strict_mode (bool, optional): Determines whether the transformer should apply
|
||||
filtering to strictly adhere to `allowed_nodes` and `allowed_relationships`.
|
||||
Defaults to True.
|
||||
node_properties (Union[bool, List[str]]): If True, the LLM can extract any
|
||||
node properties from text. Alternatively, a list of valid properties can
|
||||
be provided for the LLM to extract, restricting extraction to those specified.
|
||||
relationship_properties (Union[bool, List[str]]): If True, the LLM can extract
|
||||
any relationship properties from text. Alternatively, a list of valid
|
||||
properties can be provided for the LLM to extract, restricting extraction to
|
||||
those specified.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
llm=ChatOpenAI(temperature=0)
|
||||
transformer = LLMGraphTransformer(
|
||||
llm=llm,
|
||||
allowed_nodes=["Person", "Organization"])
|
||||
|
||||
doc = Document(page_content="Elon Musk is suing OpenAI")
|
||||
graph_documents = transformer.convert_to_graph_documents([doc])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLanguageModel,
|
||||
allowed_nodes: List[str] = [],
|
||||
allowed_relationships: List[str] = [],
|
||||
prompt: Optional[ChatPromptTemplate] = None,
|
||||
strict_mode: bool = True,
|
||||
node_properties: Union[bool, List[str]] = False,
|
||||
relationship_properties: Union[bool, List[str]] = False,
|
||||
) -> None:
|
||||
self.allowed_nodes = allowed_nodes
|
||||
self.allowed_relationships = allowed_relationships
|
||||
self.strict_mode = strict_mode
|
||||
self._function_call = True
|
||||
# Check if the LLM really supports structured output
|
||||
try:
|
||||
llm.with_structured_output(_Graph)
|
||||
except NotImplementedError:
|
||||
self._function_call = False
|
||||
if not self._function_call:
|
||||
if node_properties or relationship_properties:
|
||||
raise ValueError(
|
||||
"The 'node_properties' and 'relationship_properties' parameters "
|
||||
"cannot be used in combination with a LLM that doesn't support "
|
||||
"native function calling."
|
||||
)
|
||||
try:
|
||||
import json_repair # type: ignore
|
||||
|
||||
self.json_repair = json_repair
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import json_repair python package. "
|
||||
"Please install it with `pip install json-repair`."
|
||||
)
|
||||
prompt = prompt or create_unstructured_prompt(
|
||||
allowed_nodes, allowed_relationships
|
||||
)
|
||||
self.chain = prompt | llm
|
||||
else:
|
||||
# Define chain
|
||||
try:
|
||||
llm_type = llm._llm_type # type: ignore
|
||||
except AttributeError:
|
||||
llm_type = None
|
||||
schema = create_simple_model(
|
||||
allowed_nodes,
|
||||
allowed_relationships,
|
||||
node_properties,
|
||||
llm_type,
|
||||
relationship_properties,
|
||||
)
|
||||
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
||||
prompt = prompt or default_prompt
|
||||
self.chain = prompt | structured_llm
|
||||
|
||||
def process_response(
|
||||
self, document: Document, config: Optional[RunnableConfig] = None
|
||||
) -> GraphDocument:
|
||||
"""
|
||||
Processes a single document, transforming it into a graph document using
|
||||
an LLM based on the model's schema and constraints.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = self.chain.invoke({"input": text}, config=config)
|
||||
if self._function_call:
|
||||
raw_schema = cast(Dict[Any, Any], raw_schema)
|
||||
nodes, relationships = _convert_to_graph_document(raw_schema)
|
||||
else:
|
||||
nodes_set = set()
|
||||
relationships = []
|
||||
if not isinstance(raw_schema, str):
|
||||
raw_schema = raw_schema.content
|
||||
parsed_json = self.json_repair.loads(raw_schema)
|
||||
for rel in parsed_json:
|
||||
# Nodes need to be deduplicated using a set
|
||||
nodes_set.add((rel["head"], rel["head_type"]))
|
||||
nodes_set.add((rel["tail"], rel["tail_type"]))
|
||||
|
||||
source_node = Node(id=rel["head"], type=rel["head_type"])
|
||||
target_node = Node(id=rel["tail"], type=rel["tail_type"])
|
||||
relationships.append(
|
||||
Relationship(
|
||||
source=source_node, target=target_node, type=rel["relation"]
|
||||
)
|
||||
)
|
||||
# Create nodes list
|
||||
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
|
||||
|
||||
# Strict mode filtering
|
||||
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
||||
if self.allowed_nodes:
|
||||
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
|
||||
nodes = [
|
||||
node for node in nodes if node.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.source.type.lower() in lower_allowed_nodes
|
||||
and rel.target.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
if self.allowed_relationships:
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.type.lower()
|
||||
in [el.lower() for el in self.allowed_relationships]
|
||||
]
|
||||
|
||||
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
|
||||
|
||||
def convert_to_graph_documents(
|
||||
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
|
||||
) -> List[GraphDocument]:
|
||||
"""Convert a sequence of documents into graph documents.
|
||||
|
||||
Args:
|
||||
documents (Sequence[Document]): The original documents.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Sequence[GraphDocument]: The transformed documents as graphs.
|
||||
"""
|
||||
return [self.process_response(document, config) for document in documents]
|
||||
|
||||
async def aprocess_response(
|
||||
self, document: Document, config: Optional[RunnableConfig] = None
|
||||
) -> GraphDocument:
|
||||
"""
|
||||
Asynchronously processes a single document, transforming it into a
|
||||
graph document.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = await self.chain.ainvoke({"input": text}, config=config)
|
||||
raw_schema = cast(Dict[Any, Any], raw_schema)
|
||||
nodes, relationships = _convert_to_graph_document(raw_schema)
|
||||
|
||||
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
||||
if self.allowed_nodes:
|
||||
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
|
||||
nodes = [
|
||||
node for node in nodes if node.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.source.type.lower() in lower_allowed_nodes
|
||||
and rel.target.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
if self.allowed_relationships:
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.type.lower()
|
||||
in [el.lower() for el in self.allowed_relationships]
|
||||
]
|
||||
|
||||
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
|
||||
|
||||
async def aconvert_to_graph_documents(
|
||||
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
|
||||
) -> List[GraphDocument]:
|
||||
"""
|
||||
Asynchronously convert a sequence of documents into graph documents.
|
||||
"""
|
||||
tasks = [
|
||||
asyncio.create_task(self.aprocess_response(document, config))
|
||||
for document in documents
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
14
backend/database.py
Normal file
14
backend/database.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from sqlalchemy import create_engine, Column, Integer, String
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from envs import POSTGRES_DATABASE_URL
|
||||
|
||||
|
||||
engine = create_engine(
|
||||
POSTGRES_DATABASE_URL
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
10
backend/envs.py
Normal file
10
backend/envs.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
#POSTGRES DB TO TRACK USERS
|
||||
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
|
||||
|
||||
# API KEY TO VERIFY
|
||||
API_SECRET_KEY = "surfsense"
|
||||
|
||||
# Your JWT secret and algorithm
|
||||
SECRET_KEY = "your_secret_key"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 720
|
||||
14
backend/models.py
Normal file
14
backend/models.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from sqlalchemy import Column, Integer, String
|
||||
from database import Base
|
||||
from database import engine
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
|
||||
# Create the database tables if they don't exist
|
||||
|
||||
User.metadata.create_all(bind=engine)
|
||||
84
backend/prompts.py
Normal file
84
backend/prompts.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
||||
|
||||
CYPHER_QA_TEMPLATE = DATE_TODAY + """You are an assistant that helps to form nice and human understandable answers.
|
||||
The information part contains the provided information that you must use to construct an answer.
|
||||
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
|
||||
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
|
||||
Here are the examples:
|
||||
|
||||
Question: Website on which the most time was spend on?
|
||||
Context:[{'d.VisitedWebPageURL': 'https://stackoverflow.com/questions/59873698/the-default-export-is-not-a-react-component-in-page-nextjs', 'totalDuration': 8889167}]
|
||||
Helpful Answer: You visited https://stackoverflow.com/questions/59873698/the-default-export-is-not-a-react-component-in-page-nextjs for 8889167 milliseconds or 8889.167 seconds.
|
||||
|
||||
Follow this example when generating answers.
|
||||
If the provided information is empty, then and only then, return exactly 'don't know' as answer.
|
||||
|
||||
Information:
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
Helpful Answer:"""
|
||||
|
||||
CYPHER_QA_PROMPT = PromptTemplate(
|
||||
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
|
||||
)
|
||||
|
||||
SIMILARITY_SEARCH_RAG = DATE_TODAY + """You are an assistant for question-answering tasks.
|
||||
Use the following pieces of retrieved context to answer the question.
|
||||
If you don't know the answer, return exactly 'don't know' as answer.
|
||||
Question: {question}
|
||||
Context: {context}
|
||||
Answer:"""
|
||||
|
||||
|
||||
SIMILARITY_SEARCH_PROMPT = PromptTemplate(
|
||||
input_variables=["context", "question"], template=SIMILARITY_SEARCH_RAG
|
||||
)
|
||||
|
||||
# doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
|
||||
|
||||
|
||||
CYPHER_GENERATION_TEMPLATE = DATE_TODAY + """Task:Generate Cypher statement to query a graph database.
|
||||
Instructions:
|
||||
Use only the provided relationship types and properties in the schema.
|
||||
Do not use any other relationship types or properties that are not provided.
|
||||
|
||||
Schema:
|
||||
{schema}
|
||||
Note: Do not include any explanations or apologies in your responses.
|
||||
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
||||
Do not include any text except the generated Cypher statement.
|
||||
|
||||
|
||||
The question is:
|
||||
{question}"""
|
||||
CYPHER_GENERATION_PROMPT = PromptTemplate(
|
||||
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
|
||||
)
|
||||
|
||||
|
||||
DOC_DESCRIPTION_TEMPLATE = """Task:Give Detailed Description of the page content of the given document.
|
||||
Instructions:
|
||||
Provide as much details about metadata & page content as if you need to give human readable report of this Browsing session event.
|
||||
|
||||
Document:
|
||||
{document}
|
||||
"""
|
||||
DOC_DESCRIPTION_PROMPT = PromptTemplate(
|
||||
input_variables=["document"], template=DOC_DESCRIPTION_TEMPLATE
|
||||
)
|
||||
|
||||
|
||||
DOCUMENT_METADATA_EXTRACTION_SYSTEM_MESSAGE = DATE_TODAY + """You are a helpful assistant. You are given a Cypher statement result after quering the Neo4j graph database.
|
||||
Generate a very good Query that can be used to perform similarity search on the vector store of the Neo4j graph database"""
|
||||
|
||||
DOCUMENT_METADATA_EXTRACTION_PROMT = ChatPromptTemplate.from_messages([("system", DOCUMENT_METADATA_EXTRACTION_SYSTEM_MESSAGE), ("human", "{input}")])
|
||||
|
||||
|
||||
|
||||
|
||||
41
backend/pydmodels.py
Normal file
41
backend/pydmodels.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
class UserQuery(BaseModel):
|
||||
query: str
|
||||
neourl: str
|
||||
neouser: str
|
||||
neopass: str
|
||||
openaikey: str
|
||||
apisecretkey: str
|
||||
|
||||
class DescriptionResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
class DocMeta(BaseModel):
|
||||
BrowsingSessionId: Optional[str] = Field(default=None, description="BrowsingSessionId of Document")
|
||||
VisitedWebPageURL: Optional[str] = Field(default=None, description="VisitedWebPageURL of Document")
|
||||
VisitedWebPageTitle: Optional[str] = Field(default=None, description="VisitedWebPageTitle of Document")
|
||||
VisitedWebPageDateWithTimeInISOString: Optional[str] = Field(default=None, description="VisitedWebPageDateWithTimeInISOString of Document")
|
||||
VisitedWebPageReffererURL: Optional[str] = Field(default=None, description="VisitedWebPageReffererURL of Document")
|
||||
VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document"),
|
||||
VisitedWebPageContent: Optional[str] = Field(default=None, description="Visited WebPage Content in markdown of Document")
|
||||
|
||||
class RetrivedDocListItem(BaseModel):
|
||||
metadata: DocMeta
|
||||
pageContent: str
|
||||
|
||||
class RetrivedDocList(BaseModel):
|
||||
documents: List[RetrivedDocListItem]
|
||||
neourl: str
|
||||
neouser: str
|
||||
neopass: str
|
||||
openaikey: str
|
||||
apisecretkey: str
|
||||
|
||||
|
||||
class UserQueryResponse(BaseModel):
|
||||
response: str
|
||||
relateddocs: List[DocMeta]
|
||||
|
||||
8
backend/requirements.txt
Normal file
8
backend/requirements.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
bcrypt
|
||||
cryptography
|
||||
fastapi
|
||||
python-jose
|
||||
python-multipart
|
||||
SQLAlchemy
|
||||
uvicorn
|
||||
passlib
|
||||
322
backend/server.py
Normal file
322
backend/server.py
Normal file
|
|
@ -0,0 +1,322 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from langchain.chains import GraphCypherQAChain
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.vectorstores import Neo4jVector
|
||||
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
||||
from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
||||
from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
|
||||
#Our Imps
|
||||
from LLMGraphTransformer import LLMGraphTransformer
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# Auth Libs
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from models import User
|
||||
from database import SessionLocal, engine
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.post("/")
|
||||
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||
|
||||
if(data.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
query = data.query
|
||||
|
||||
graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=data.openaikey
|
||||
)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
api_key=data.openaikey,
|
||||
)
|
||||
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
graph=graph,
|
||||
cypher_prompt=CYPHER_GENERATION_PROMPT,
|
||||
cypher_llm=llm,
|
||||
verbose=True,
|
||||
validate_cypher=True,
|
||||
qa_prompt=CYPHER_QA_PROMPT ,
|
||||
qa_llm=llm,
|
||||
return_intermediate_steps=True,
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
vector_index = Neo4jVector.from_existing_graph(
|
||||
embeddings,
|
||||
graph=graph,
|
||||
search_type="hybrid",
|
||||
node_label="Document",
|
||||
text_node_properties=["text"],
|
||||
embedding_node_property="embedding",
|
||||
)
|
||||
|
||||
docs = vector_index.similarity_search(query,k=5)
|
||||
|
||||
docstoreturn = []
|
||||
|
||||
for doc in docs:
|
||||
docstoreturn.append(
|
||||
DocMeta(
|
||||
BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE",
|
||||
VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None,
|
||||
VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE"
|
||||
)
|
||||
)
|
||||
|
||||
docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
|
||||
|
||||
|
||||
try:
|
||||
response = chain.invoke({"query": query})
|
||||
if "don't know" in response["result"]:
|
||||
raise Exception("No response from graph")
|
||||
|
||||
structured_llm = llm.with_structured_output(RetrivedDocList)
|
||||
doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
|
||||
|
||||
query = doc_extract_chain.invoke(response["intermediate_steps"][1]["context"])
|
||||
|
||||
docs = vector_index.similarity_search(query.searchquery,k=5)
|
||||
|
||||
docstoreturn = []
|
||||
|
||||
for doc in docs:
|
||||
docstoreturn.append(
|
||||
DocMeta(
|
||||
BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE",
|
||||
VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None,
|
||||
VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE"
|
||||
)
|
||||
)
|
||||
|
||||
docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
|
||||
|
||||
return UserQueryResponse(relateddocs=docstoreturn,response=response["result"])
|
||||
except:
|
||||
# Fallback to Similarity Search RAG
|
||||
searchchain = SIMILARITY_SEARCH_PROMPT | llm
|
||||
|
||||
response = searchchain.invoke({"question": query, "context": docs})
|
||||
|
||||
return UserQueryResponse(relateddocs=docstoreturn,response=response.content)
|
||||
|
||||
# DOC DESCRIPTION
|
||||
@app.post("/kb/doc")
|
||||
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
|
||||
if(data.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
document = data.query
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=data.openaikey
|
||||
)
|
||||
|
||||
descriptionchain = DOC_DESCRIPTION_PROMPT | llm
|
||||
|
||||
response = descriptionchain.invoke({"document": document})
|
||||
|
||||
return DescriptionResponse(response=response.content)
|
||||
|
||||
|
||||
# SAVE DOCS TO GRAPH DB
|
||||
@app.post("/kb/")
|
||||
def populate_graph(apires: RetrivedDocList):
|
||||
if(apires.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
print("STARTED")
|
||||
# print(apires)
|
||||
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=apires.openaikey
|
||||
)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
api_key=apires.openaikey,
|
||||
)
|
||||
|
||||
llm_transformer = LLMGraphTransformer(llm=llm)
|
||||
|
||||
raw_documents = []
|
||||
|
||||
for doc in apires.documents:
|
||||
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
|
||||
|
||||
text_splitter = SemanticChunker(embeddings=embeddings)
|
||||
|
||||
documents = text_splitter.split_documents(raw_documents)
|
||||
graph_documents = llm_transformer.convert_to_graph_documents(documents)
|
||||
|
||||
|
||||
graph.add_graph_documents(
|
||||
graph_documents,
|
||||
baseEntityLabel=True,
|
||||
include_source=True
|
||||
)
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
return {
|
||||
"success": "Graph Will be populated Shortly"
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#AUTH CODE
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
# Recommended for Local Setups
|
||||
# origins = [
|
||||
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
|
||||
# "https://yourfrontenddomain.com",
|
||||
# ]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins from the list
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
apisecretkey: str
|
||||
|
||||
def get_user_by_username(db: Session, username: str):
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
def create_user(db: Session, user: UserCreate):
|
||||
hashed_password = pwd_context.hash(user.password)
|
||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return "complete"
|
||||
|
||||
@app.post("/register")
|
||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
if(user.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
db_user = get_user_by_username(db, username=user.username)
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
del user.apisecretkey
|
||||
return create_user(db=db, user=user)
|
||||
|
||||
# Authenticate the user
|
||||
def authenticate_user(username: str, password: str, db: Session):
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return False
|
||||
if not pwd_context.verify(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
# Create access token
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@app.post("/token")
|
||||
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = authenticate_user(form_data.username, form_data.password, db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
return payload
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/verify-token/{token}")
|
||||
async def verify_user_token(token: str):
|
||||
verify_token(token=token)
|
||||
return {"message": "Token is valid"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||
Loading…
Add table
Add a link
Reference in a new issue