mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-28 18:36:23 +02:00
feat: gpt-researcher custom response.Now very close to perplexity.
This commit is contained in:
parent
dfb0967dbe
commit
46c9b228df
5 changed files with 215 additions and 194 deletions
|
|
@ -1,3 +1,7 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from gpt_researcher import GPTResearcher
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_ollama import OllamaLLM, OllamaEmbeddings
|
||||
from langchain_community.vectorstores.utils import filter_complex_metadata
|
||||
|
|
@ -14,12 +18,23 @@ from langchain_core.prompts import PromptTemplate
|
|||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from pydmodels import AIAnswer, Reference
|
||||
from database import SessionLocal
|
||||
from models import Documents, User
|
||||
from prompts import CONTEXT_ANSWER_PROMPT
|
||||
load_dotenv()
|
||||
|
||||
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
||||
FAST_LLM = os.environ.get("FAST_LLM")
|
||||
EMBEDDING = os.environ.get("EMBEDDING")
|
||||
IS_LOCAL_SETUP = True if FAST_LLM.startswith("ollama") else False
|
||||
|
||||
|
||||
def extract_model_name(model_string: str) -> tuple[str, str]:
|
||||
part1, part2 = model_string.split(":", 1) # Split into two parts at the first colon
|
||||
return part2
|
||||
|
||||
MODEL_NAME = extract_model_name(FAST_LLM)
|
||||
EMBEDDING_MODEL = extract_model_name(EMBEDDING)
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
|
|
@ -35,12 +50,12 @@ class HIndices:
|
|||
"""
|
||||
"""
|
||||
self.username = username
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
self.llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
||||
self.embeddings = OllamaEmbeddings(model="mistral-nemo")
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
self.llm = OllamaLLM(model=MODEL_NAME,temperature=0)
|
||||
self.embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
|
||||
else:
|
||||
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=api_key)
|
||||
self.embeddings = OpenAIEmbeddings(api_key=api_key)
|
||||
self.llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=api_key)
|
||||
self.embeddings = OpenAIEmbeddings(api_key=api_key,model=EMBEDDING_MODEL)
|
||||
|
||||
self.summary_store = Chroma(
|
||||
collection_name="summary_store",
|
||||
|
|
@ -92,34 +107,25 @@ class HIndices:
|
|||
report_chain = report_prompt | self.llm
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
# Local LLMS suck at summaries so need this slow and painful procedure
|
||||
text_splitter = SemanticChunker(embeddings=self.embeddings)
|
||||
chunks = text_splitter.split_documents([doc])
|
||||
combined_summary = ""
|
||||
for i, chunk in enumerate(chunks):
|
||||
print("GENERATING SUMMARY FOR CHUNK "+ str(i))
|
||||
chunk_summary = report_chain.invoke({"document": chunk})
|
||||
combined_summary += "\n\n" + chunk_summary + "\n\n"
|
||||
|
||||
response = combined_summary
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
|
||||
metadict = {
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
}
|
||||
|
||||
# metadict['languages'] = metadict['languages'][0]
|
||||
|
||||
metadict.update(doc.metadata)
|
||||
|
||||
# metadict['languages'] = metadict['languages'][0]
|
||||
|
||||
return Document(
|
||||
id=str(page_no),
|
||||
page_content=response,
|
||||
metadata=metadict
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
|
@ -177,17 +183,8 @@ class HIndices:
|
|||
report_chain = report_prompt | self.llm
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
# Local LLMS suck at summaries so need this slow and painful procedure
|
||||
text_splitter = SemanticChunker(embeddings=self.embeddings)
|
||||
chunks = text_splitter.split_documents([doc])
|
||||
combined_summary = ""
|
||||
for i, chunk in enumerate(chunks):
|
||||
print("GENERATING SUMMARY FOR CHUNK "+ str(i))
|
||||
chunk_summary = report_chain.invoke({"document": chunk})
|
||||
combined_summary += "\n\n" + chunk_summary + "\n\n"
|
||||
|
||||
response = combined_summary
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
return Document(
|
||||
id=str(page_no),
|
||||
|
|
@ -204,7 +201,8 @@ class HIndices:
|
|||
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
|
||||
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
|
|
@ -229,19 +227,6 @@ class HIndices:
|
|||
"""
|
||||
Creates and Saves/Updates docs in hierarchical indices and postgres table
|
||||
"""
|
||||
|
||||
|
||||
# DocumentPgEntry = []
|
||||
# searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == search_space).first()
|
||||
|
||||
# for doc in documents:
|
||||
# pgdocmeta = stringify(doc.metadata)
|
||||
|
||||
# if(searchspace):
|
||||
# DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=search_space, document_metadata=pgdocmeta, page_content=doc.page_content))
|
||||
# else:
|
||||
# DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=search_space), document_metadata=pgdocmeta, page_content=doc.page_content))
|
||||
|
||||
|
||||
prev_doc_idx = len(documents) + 1
|
||||
# #Save docs in PG
|
||||
|
|
@ -262,22 +247,20 @@ class HIndices:
|
|||
else:
|
||||
batch_summaries = [self.summarize_file_doc(page_no = i + summary_last_id, doc=doc, search_space=search_space) for i, doc in enumerate(documents)]
|
||||
|
||||
# batch_summaries = [summarize_doc(i + summary_last_id, doc) for i, doc in enumerate(documents)]
|
||||
|
||||
summaries.extend(batch_summaries)
|
||||
|
||||
detailed_chunks = []
|
||||
|
||||
for i, summary in enumerate(summaries):
|
||||
|
||||
# Semantic chucking for better contexual comprression
|
||||
# Semantic chucking for better contexual compression
|
||||
text_splitter = SemanticChunker(embeddings=self.embeddings)
|
||||
chunks = text_splitter.split_documents([documents[i]])
|
||||
|
||||
user.documents[-(len(summaries) - i)].desc_vector_start = detail_id_counter
|
||||
user.documents[-(len(summaries) - i)].desc_vector_end = detail_id_counter + len(chunks)
|
||||
# summary_entry = db.query(Documents).filter(Documents.id == int(user.documents[-1].id)).first()
|
||||
# summary_entry.desc_vector_start = detail_id_counter
|
||||
# summary_entry.desc_vector_end = detail_id_counter + len(chunks)
|
||||
|
||||
|
||||
db.commit()
|
||||
|
||||
|
|
@ -290,6 +273,30 @@ class HIndices:
|
|||
"page": summary.metadata['page'],
|
||||
})
|
||||
|
||||
if(files_type == 'WEBPAGE'):
|
||||
ieee_content = (
|
||||
f"=======================================DOCUMENT METADATA==================================== \n"
|
||||
f"Source: {chunk.metadata['VisitedWebPageURL']} \n"
|
||||
f"Title: {chunk.metadata['VisitedWebPageTitle']} \n"
|
||||
f"Visited Date and Time : {chunk.metadata['VisitedWebPageDateWithTimeInISOString']} \n"
|
||||
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
|
||||
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
|
||||
f"===================================================================================== \n"
|
||||
)
|
||||
|
||||
else:
|
||||
ieee_content = (
|
||||
f"=======================================DOCUMENT METADATA==================================== \n"
|
||||
f"Source: {chunk.metadata['filename']} \n"
|
||||
f"Title: {chunk.metadata['filename']} \n"
|
||||
f"Visited Date and Time : {datetime.now()} \n"
|
||||
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
|
||||
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
|
||||
f"===================================================================================== \n"
|
||||
)
|
||||
|
||||
chunk.page_content = ieee_content
|
||||
|
||||
detail_id_counter += 1
|
||||
|
||||
detailed_chunks.extend(chunks)
|
||||
|
|
@ -312,34 +319,68 @@ class HIndices:
|
|||
db.commit()
|
||||
|
||||
return "success"
|
||||
|
||||
def is_query_answerable(self, query, context):
|
||||
prompt = PromptTemplate(
|
||||
template="""You are a grader assessing relevance of a retrieved document to a user question. \n
|
||||
Here is the retrieved document: \n\n {context} \n\n
|
||||
Here is the user question: {question} \n
|
||||
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
||||
Only return 'yes' or 'no'""",
|
||||
input_variables=["context", "question"],
|
||||
)
|
||||
|
||||
def summary_vector_search(self,query, search_space='GENERAL'):
|
||||
top_summaries_compressor = FlashrankRerank(top_n=20)
|
||||
|
||||
ans_chain = prompt | self.llm
|
||||
|
||||
finalans = ans_chain.invoke({"question": query, "context": context})
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return finalans
|
||||
else:
|
||||
return finalans.content
|
||||
|
||||
def local_search(self, query, search_space='GENERAL'):
|
||||
top_summaries_compressor = FlashrankRerank(top_n=5)
|
||||
details_compressor = FlashrankRerank(top_n=30)
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
|
||||
)
|
||||
|
||||
return top_summaries_retreiver.invoke(query)
|
||||
|
||||
def deduplicate_references_and_update_answer(self, answer: str, references: List[Reference]) -> tuple[str, List[Reference]]:
|
||||
"""
|
||||
Deduplicates references and updates the answer text to maintain correct reference numbering.
|
||||
|
||||
Args:
|
||||
answer: The text containing reference citations
|
||||
references: List of Reference objects
|
||||
|
||||
Returns:
|
||||
tuple: (updated_answer, deduplicated_references)
|
||||
"""
|
||||
# Track unique references and create ID mapping using a dictionary comprehension
|
||||
unique_refs = {}
|
||||
id_mapping = {
|
||||
ref.id: unique_refs.setdefault(
|
||||
ref.url, Reference(id=str(len(unique_refs) + 1), title=ref.title, url=ref.url)
|
||||
).id
|
||||
for ref in references
|
||||
}
|
||||
|
||||
# Apply new mappings to the answer text
|
||||
updated_answer = answer
|
||||
for old_id, new_id in sorted(id_mapping.items(), key=lambda x: len(x[0]), reverse=True):
|
||||
updated_answer = updated_answer.replace(f'[{old_id}]', f'[{new_id}]')
|
||||
|
||||
return updated_answer, list(unique_refs.values())
|
||||
|
||||
async def get_vectorstore_report(self, query: str, report_type: str, report_source: str, documents: List[Document]) -> str:
|
||||
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, documents=documents, report_format="IEEE")
|
||||
await researcher.conduct_research()
|
||||
report = await researcher.write_report()
|
||||
return report
|
||||
|
||||
async def get_web_report(self, query: str, report_type: str, report_source: str) -> str:
|
||||
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, report_format="IEEE")
|
||||
await researcher.conduct_research()
|
||||
report = await researcher.write_report()
|
||||
return report
|
||||
|
||||
def new_search(self, query, search_space='GENERAL'):
|
||||
report_type = "custom_report"
|
||||
report_source = "langchain_documents"
|
||||
contextdocs = []
|
||||
|
||||
|
||||
|
||||
top_summaries_compressor = FlashrankRerank(top_n=5)
|
||||
details_compressor = FlashrankRerank(top_n=50)
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})#
|
||||
)
|
||||
|
||||
top_summaries_compressed_docs = top_summaries_retreiver.invoke(query)
|
||||
|
||||
for summary in top_summaries_compressed_docs:
|
||||
|
|
@ -349,66 +390,45 @@ class HIndices:
|
|||
detailed_compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=details_compressor, base_retriever=self.detailed_store.as_retriever(search_kwargs={'filter': {'page': page_number}})
|
||||
)
|
||||
|
||||
|
||||
detailed_compressed_docs = detailed_compression_retriever.invoke(
|
||||
query
|
||||
)
|
||||
|
||||
contextdocs = top_summaries_compressed_docs + detailed_compressed_docs
|
||||
|
||||
context_to_answer = ""
|
||||
for i, doc in enumerate(contextdocs):
|
||||
content = f":DOCUMENT {str(i)}\n"
|
||||
content += f"=======================================METADATA==================================== \n"
|
||||
content += f"{doc.metadata} \n"
|
||||
content += f"===================================================================================== \n"
|
||||
content += f"DOCUMENT CONTENT: \n\n {doc.page_content} \n\n"
|
||||
content += f"===================================================================================== \n"
|
||||
|
||||
context_to_answer += content
|
||||
|
||||
content = ""
|
||||
|
||||
if(self.is_query_answerable(query=query, context=context_to_answer).lower() == 'yes'):
|
||||
ans_chain = CONTEXT_ANSWER_PROMPT | self.llm
|
||||
contextdocs.extend(detailed_compressed_docs)
|
||||
|
||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return finalans
|
||||
else:
|
||||
return finalans.content
|
||||
else:
|
||||
continue
|
||||
|
||||
return "I couldn't find any answer"
|
||||
custom_prompt = """
|
||||
Please answer the following user query in the format shown below, using in-text citations and IEEE-style references based on the provided documents.
|
||||
USER QUERY : """+ query +"""
|
||||
|
||||
def global_search(self,query, search_space='GENERAL'):
|
||||
top_summaries_compressor = FlashrankRerank(top_n=20)
|
||||
Ensure the answer includes:
|
||||
- A detailed yet concise explanation with IEEE-style in-text citations (e.g., [1], [2]).
|
||||
- A list of non-duplicated sources at the end, following IEEE format. Hyperlink each source using: [Website Name](URL).
|
||||
- Where applicable, provide sources in the text to back up key points.
|
||||
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
|
||||
)
|
||||
Ensure your response is structured something like this (here user query : Explain the impact of artificial intelligence on modern healthcare.):
|
||||
---
|
||||
**Answer:**
|
||||
Artificial intelligence (AI) has significantly transformed modern healthcare by enhancing diagnostic accuracy, personalizing patient care, and optimizing operational efficiency. AI algorithms can analyze vast datasets to identify patterns that may be missed by human practitioners, leading to improved diagnostic outcomes [1]. For instance, AI systems have been deployed in radiology to detect anomalies in medical imaging with high precision [2]. Moreover, AI-driven tools facilitate personalized treatment plans by considering individual patient data, thereby improving patient outcomes [3].
|
||||
|
||||
**References:**
|
||||
1. (2024, October 23). [Highly Effective Prompt for Summarizing — GPT-4 Optimized: r/ChatGPT.](https://www.reddit.com/r/ChatGPT/comments/13na8yp/highly_effective_prompt_for_summarizing_gpt4/)
|
||||
2. (2024, October 23). [MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers.](https://github.com/MODSetter/SurfSense)
|
||||
3. filename.pdf
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
top_summaries_compressed_docs = top_summaries_retreiver.invoke(query)
|
||||
local_report = asyncio.run(self.get_vectorstore_report(query=custom_prompt, report_type=report_type, report_source=report_source, documents=contextdocs))
|
||||
|
||||
context_to_answer = ""
|
||||
for i, doc in enumerate(top_summaries_compressed_docs):
|
||||
content = f":DOCUMENT {str(i)}\n"
|
||||
content += f"=======================================METADATA==================================== \n"
|
||||
content += f"{doc.metadata} \n"
|
||||
content += f"===================================================================================== \n"
|
||||
content += f"DOCUMENT CONTENT: \n\n {doc.page_content} \n\n"
|
||||
content += f"===================================================================================== \n"
|
||||
|
||||
context_to_answer += content
|
||||
|
||||
ans_chain = CONTEXT_ANSWER_PROMPT | self.llm
|
||||
# web_report = asyncio.run(get_web_report(query=custom_prompt, report_type=report_type, report_source="web"))
|
||||
|
||||
# structured_llm = self.llm.with_structured_output(AIAnswer)
|
||||
|
||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||
# out = structured_llm.invoke("Extract exact(i.e without changing) answer string and references information from : \n\n\n" + local_report)
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return finalans, top_summaries_compressed_docs
|
||||
else:
|
||||
return finalans.content, top_summaries_compressed_docs
|
||||
# mod_out = self.deduplicate_references_and_update_answer(answer=out.answer, references=out.references)
|
||||
|
||||
return local_report
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue