Commit dc078e3 · sravan committed
Parent(s): ae692a1

first working application
Files changed:
- .gitignore +1 -1
- callbacks.py +6 -3
- chains.py +88 -8
- code_data/langchain_repo +1 -0
- data_indexing.py +132 -29
- main.py +57 -31
- prompts.py +14 -12
- sources.txt +0 -0
- test.db +0 -0
.gitignore
CHANGED
@@ -1,4 +1,4 @@
 myenv
 *pycache*

-
+dang.py
callbacks.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Dict, Any, List
 from langchain_core.callbacks import BaseCallbackHandler
 import schemas
 import crud
+from datetime import datetime


 class LogResponseCallback(BaseCallbackHandler):
@@ -16,13 +17,15 @@ class LogResponseCallback(BaseCallbackHandler):
         # TODO: The function on_llm_end is going to be called when the LLM stops sending
         # the response. Use the crud.add_message function to capture that response.
         type = 'AI'
-        user_data = crud.
-        user_id = user_data.
+        user_data = crud.get_or_create_user(self.db, self.user_request.username)
+        user_id = user_data.id
         timestamp = datetime.now()
-        message = outputs
+        message = str(outputs)  # answer from the prompt message
+        print("history messages", message)
         message_to_add = schemas.MessageBase(
             user_id = user_id,
             message = message,
+            user=user_data.username,
             type = type,
             timestamp = timestamp
         )
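Note: the hunk above fills in on_llm_end so the assistant's reply is persisted through crud.add_message. Below is a minimal, self-contained sketch of that pattern; the stubbed request object and the plain dictionary standing in for schemas.MessageBase are illustrative assumptions, since crud.py and schemas.py are not part of this commit.

```python
# Sketch of the response-logging callback pattern from callbacks.py, runnable
# without the project's crud/schemas modules (stand-ins are used instead).
from datetime import datetime
from types import SimpleNamespace

from langchain_core.callbacks import BaseCallbackHandler


class LogResponseCallbackSketch(BaseCallbackHandler):
    def __init__(self, user_request, db):
        self.user_request = user_request
        self.db = db

    def on_llm_end(self, outputs, **kwargs):
        # Capture the final LLM output and persist it for the requesting user.
        record = {
            "user": self.user_request.username,
            "message": str(outputs),
            "type": "AI",
            "timestamp": datetime.now(),
        }
        # In the real code this becomes crud.add_message(self.db, schemas.MessageBase(**record), ...)
        print("would persist:", record)


if __name__ == "__main__":
    cb = LogResponseCallbackSketch(SimpleNamespace(username="alice"), db=None)
    cb.on_llm_end(outputs="Hello from the model")
```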
chains.py
CHANGED
@@ -1,6 +1,7 @@
 import os
 from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
 from langchain_core.runnables import RunnablePassthrough
+from langchain.schema.runnable import RunnableLambda

 import schemas
 import prompts
@@ -12,7 +13,7 @@ from prompts import (
     standalone_prompt_formatted,
     rag_prompt_formatted
 )
-from data_indexing import DataIndexer
+from data_indexing import DataIndexer
 from transformers import AutoTokenizer

 data_indexer = DataIndexer()
@@ -52,37 +53,116 @@ llm_endpoint = HuggingFaceEndpoint(

 llm = ChatHuggingFace(llm=llm_endpoint)

+def print_and_pass(prompt_output):
+    print("=" * 60)
+    print("🔍 RAW PROMPT FORMATTED:")
+    print("=" * 60)
+    print(prompt_output)
+    print("=" * 60)
+    return prompt_output  # IMPORTANT: Must return the prompt unchanged
+
 simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)

+
+
 # TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
-formatted_chain = (raw_prompt_formatted | llm).with_types(input_type=schemas.UserQuestion)
+formatted_chain = (raw_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.UserQuestion)

 # TODO: use history_prompt_formatted and HistoryInput to create the history_chain
-history_chain = (history_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
+history_chain = (history_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.HistoryInput)

 # TODO: Let's construct the standalone_chain by piping standalone_prompt_formatted with the LLM
 standalone_chain = (standalone_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)

+# summarize_chain = (summarize_propt_formatted | llm)
+
+import ast
+
+def extract_definitions(source_code):
+    """
+    Extract top-level function and class definitions from Python code.
+    """
+    result = []
+    try:
+        tree = ast.parse(source_code)
+        for node in ast.iter_child_nodes(tree):
+            if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
+                snippet = ast.get_source_segment(source_code, node)
+                if snippet:
+                    result.append(snippet)
+    except Exception as e:
+        print(f"Failed to parse code: {e}")
+    return result
+
+import re
+
+def clean_code_text(code_text):
+    """
+    Remove comments and excessive blank lines for brevity.
+    """
+    # Remove multiline docstrings and comments
+    code_text = re.sub(r'"""(.*?)"""', '', code_text, flags=re.DOTALL)
+    code_text = re.sub(r"'''(.*?)'''", '', code_text, flags=re.DOTALL)
+
+    # Remove inline comments
+    code_text = re.sub(r'#.*', '', code_text)
+
+    # Remove excessive whitespace
+    code_text = re.sub(r'\n\s*\n+', '\n\n', code_text)
+
+    return code_text.strip()
+
+
+def safe_format_context(search_results):
+    try:
+        cleaned_results = []
+        for result in search_results:
+            if isinstance(result, str):
+                # Optionally: extract relevant functions/classes
+                code_parts = extract_definitions(result)
+                for part in code_parts:
+                    cleaned = clean_code_text(part)
+                    cleaned_results.append(cleaned)
+        return format_context(cleaned_results)
+    except Exception as e:
+        print(f"Error formatting context: {str(e)}")
+        return "No relevant context found."
+
+
 input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
+
+# input_1_beta = RunnablePassThrough.assign(new_context=summarize_chain)
+
+def extract_question_text(new_question):
+    if hasattr(new_question, "content"):
+        return new_question.content
+    return str(new_question)
+
+# summarize_context = {
+#     'context': lambda x: safe_format_context(data_indexer.search(extract_question_text(x['new_question']))),
+#     'standalone_question': lambda x: extract_question_text(x['new_question']),
+# }
+
 input_2 = {
-    'context': lambda x:
-    'standalone_question': lambda x: x['new_question']
+    'context': lambda x: safe_format_context(data_indexer.search(extract_question_text(x['new_question']))),
+    'standalone_question': lambda x: extract_question_text(x['new_question']),
 }
+
 input_to_rag_chain = input_1 | input_2

 # TODO: use input_to_rag_chain, rag_prompt_formatted,
 # HistoryInput and the LLM to build the rag_chain.
-rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.RagInput)
+rag_chain = (input_to_rag_chain | RunnableLambda(print_and_pass) | rag_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.RagInput)

 # TODO: Implement the filtered_rag_chain. It should be the
 # same as the rag_chain but with hybrid_search = True.

 input_2_hybrid_search = {
-    'context': lambda x:
+    'context': lambda x: safe_format_context(data_indexer.search(extract_question_text(x['new_question']), hybrid_search=True)),
     'standalone_question': lambda x: x['new_question']
 }

-filtered_rag_chain = (input_1 | input_2_hybrid_search | rag_prompt_formatted | llm ).with_types(input_type=schemas.RagInput)
+filtered_rag_chain = (input_1 | input_2_hybrid_search | rag_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.RagInput)
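Note: the RunnableLambda(print_and_pass) stages added above act as pass-through debug taps that dump the fully rendered prompt between pipeline steps. Below is a small sketch of that pattern; it swaps the HuggingFace endpoint for a fake echo stage so it runs offline, and all names here are illustrative rather than part of this repo.

```python
# Minimal sketch of the "print and pass through" debug tap used in chains.py,
# wired into a toy LCEL pipeline so it runs without a HuggingFace endpoint.
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda

prompt = PromptTemplate.from_template("Question: {question}\nAnswer:")


def print_and_pass(value):
    # Dump whatever flows through this stage, then hand it on unchanged.
    print("=" * 40)
    print(value)
    print("=" * 40)
    return value


# Stand-in for `llm`: echoes the rendered prompt back as a string.
fake_llm = RunnableLambda(lambda prompt_value: f"(model answer for) {prompt_value.to_string()}")

debug_chain = prompt | RunnableLambda(print_and_pass) | fake_llm

if __name__ == "__main__":
    print(debug_chain.invoke({"question": "What is LCEL?"}))
```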
code_data/langchain_repo
ADDED
@@ -0,0 +1 @@
+Subproject commit 2d0713c2fc5a457578635b03b7a00e970ce534ee
data_indexing.py
CHANGED
@@ -6,9 +6,41 @@ from pinecone import ServerlessSpec
 from langchain_community.vectorstores import Chroma
 from langchain_openai import OpenAIEmbeddings
 from huggingface_hub import InferenceClient
+from typing import List
+from datetime import datetime
+from sentence_transformers import SentenceTransformer
+from langchain.embeddings.base import Embeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
+import json
+

 current_dir = Path(__file__).resolve().parent

+class SentenceTransfmEmbeddings(Embeddings):
+    """Sentence Transformers embedding class"""
+
+    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2"):
+        self.model = SentenceTransformer(model_name)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents"""
+        try:
+            embeddings = self.model.encode(texts)
+            return embeddings.tolist()
+        except Exception as e:
+            print(f"Error embedding documents: {e}")
+            # Return dummy embeddings to prevent crash
+            return [[0.0] * 768 for _ in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a single query"""
+        try:
+            embedding = self.model.encode([text])
+            return embedding[0].tolist()
+        except Exception as e:
+            print(f"Error embedding query: {e}")
+            return [0.0] * 768
+

 class DataIndexer:

@@ -18,9 +50,16 @@ class DataIndexer:

         # TODO: choose your embedding model
         self.embedding_client = InferenceClient(
-            "dunzhang/stella_en_1.5B_v5",
+            # "dunzhang/stella_en_1.5B_v5",
+            "sentence-transformers/all-mpnet-base-v2",
             token=os.environ['HF_TOKEN'],
         )
+        self.embeddings = SentenceTransfmEmbeddings(
+            "sentence-transformers/all-mpnet-base-v2"
+        )
+        # self.embeddings = HuggingFaceEmbeddings(
+        #     model_name="sentence-transformers/all-mpnet-base-v2"
+        # )
         self.spec = ServerlessSpec(
             cloud = 'aws',
             region='us-east-1'
@@ -34,14 +73,22 @@ class DataIndexer:
         # Make sure to choose the dimension that corresponds to your embedding model
         self.pinecone_client.create_index(
             name=index_name,
-            dimension=
+            dimension=768,
             metric='cosine',
             spec=self.spec
         )

         self.index = self.pinecone_client.Index(self.index_name)
         # TODO: make sure to build the index.
-
+        # with open(self.source_file, 'r') as file:
+        #     sources = file.readlines()
+
+        # sources = [s.strip() for s in sources if s.strip()]
+        # if not sources:
+        #     self.source_index = None
+        # else:
+        #     self.source_index = self.get_source_index()
+        self.source_index=None

     def get_source_index(self):
         if not os.path.isfile(self.source_file):
@@ -53,9 +100,17 @@ class DataIndexer:
         with open(self.source_file, 'r') as file:
             sources = file.readlines()

-        sources = [s.
+        sources = [s.strip() for s in sources if s.strip()]
+        if not sources:
+            print("No valid sources to index")
+            return None
+        print("sources are:", sources)
+        ## testing
+        embeddings = self.embeddings.embed_documents(sources)
+        print(f"Generated {len(embeddings)} embeddings for {len(sources)} sources")
+        ## testing
         vectorstore = Chroma.from_texts(
-            sources, embedding=self.
+            sources, embedding=self.embeddings
         )
         return vectorstore

@@ -64,19 +119,21 @@ class DataIndexer:
         with open(self.source_file, 'a') as file:
             for doc in docs:
                 file.writelines(doc.metadata['source'] + '\n')
+
+        self.source_index = self.get_source_index()

         for i in range(0, len(docs), batch_size):
             batch = docs[i: i + batch_size]

             # TODO: create a list of the vector representations of each text data in the batch
             # TODO: choose your embedding model
-            values = self.embedding_client.embed_documents([
-                doc.page_content for doc in batch
-            ])
-
-            # values = self.embedding_client.feature_extraction([
+            # values = self.embedding_client.embed_documents([
             #     doc.page_content for doc in batch
             # ])
+
+            values = self.embedding_client.feature_extraction([
+                doc.page_content for doc in batch
+            ])
             # values = None

             # TODO: create a list of unique identifiers for each element in the batch with the uuid package.
@@ -85,7 +142,7 @@ class DataIndexer:
             # TODO: create a list of dictionaries representing the metadata. Capture the text data
             # with the "text" key, and make sure to capture the rest of the doc.metadata.
             metadatas = [{"text": doc.page_content,
-                          **doc.metadata
+                          **(doc.metadata if doc.metadata else {})
                          } for doc in batch]

             # create a list of dictionaries with keys "id" (the unique identifiers), "values"
@@ -96,6 +153,8 @@ class DataIndexer:
                 'metadata': metadata
             } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]

+            for v in vectors[:5]:
+                print("Metadata:", v['metadata'])
             try:
                 # TODO: Use the function upsert to upload the data to the database.
                 upsert_response = self.index.upsert(vectors)
@@ -111,28 +170,48 @@ class DataIndexer:
         # to the question. Make sure to adjust this number as you see fit.
         source_docs = self.source_index.similarity_search(text_query, 50)
         filter = {"source": {"$in":[doc.page_content for doc in source_docs]}}
-
+        result=""
         # TODO: embed the text_query by using the embedding model
         # TODO: choose your embedding model
         # vector = self.embedding_client.feature_extraction(text_query)
-
-
+        try:
+            print("text")
+            print(text_query)
+            vector = self.embedding_client.feature_extraction(
+                text = text_query,
+            )
+            if vector is None:
+                print("failed to embed the text query in vector search query for pinecone")
+                return []
+            else:
+                print("debug1_result")
+                result = self.index.query(vector,
+                    filter=filter,
+                    top_k=top_k,
+                    include_values=True,
+                    include_metadata=True
+                )
+                print(f"debugged_result query successful without error for the question:{text_query}")
+
+                docs = []
+                # print(f" none type in result? {result}")
+                for res in result["matches"]:
+                    # TODO: From the result's metadata, extract the "text" element.
+                    print("results filename:", res['metadata']['file_name'])
+                    print("result score:", res['score'])
+                    if res['score'] > 0.540:
+                        docs.append(res['metadata']['text'])
+                    # pass
+                # print("docs: ", docs[0])
+
+                return docs
+        except Exception as e:
+            print(f"error in search:{e}")
+            return []

         # TODO: use the vector representation of the text_query to
         # search the database by using the query function.
-
-            filter=filter,
-            top_k=top_k,
-            include_values=True
-        )
-
-        docs = []
-        for res in result["matches"]:
-            # TODO: From the result's metadata, extract the "text" element.
-            docs.append(res['metadata']['text'])
-            # pass
-
-        return docs
+


 if __name__ == '__main__':
@@ -142,6 +221,7 @@ if __name__ == '__main__':
         Language,
         RecursiveCharacterTextSplitter,
     )
+    print("start:", datetime.now())

     loader = GitLoader(
         clone_url="https://github.com/langchain-ai/langchain",
@@ -159,9 +239,32 @@ if __name__ == '__main__':
     docs = python_splitter.split_documents(docs)
     for doc in docs:
         doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content
-
+    print("before instantiating the indexer:", datetime.now())
     indexer = DataIndexer()
-
+    print("after instantiating the indexer:", datetime.now())
+    with open('./app/sources.txt', 'a') as file:
         for doc in docs:
             file.writelines(doc.metadata['source'] + '\n')
+    print("after writing the sources:", datetime.now())
     indexer.index_data(docs)
+    print("end:", datetime.now())
+
+    # ###### test ###########
+    # test_docs = docs[:2]  # Just try first two documents
+    # print("\nTest Document Details:")
+    # print(f"Number of test documents: {len(test_docs)}")
+    # for idx, doc in enumerate(test_docs):
+    #     print(f"\nDocument {idx + 1}:")
+    #     print(f"Content length: {len(doc.page_content)}")
+    #     # print(f"First 100 chars: {doc.page_content[:100]}")
+    #     print(f"Metadata: {doc.metadata}")
+
+    # # try:
+    #     print("\nInitializing DataIndexer...")
+    #     indexer = DataIndexer()
+    #     print("\nStarting indexing...")
+    #     indexer.index_data(test_docs)
+    #     print("Test indexing successful")
+    # # except Exception as e:
+    # #     print(f"Test indexing failed: {str(e)}")
+
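Note: the search method above now keeps only Pinecone matches whose similarity score clears 0.540 before returning their stored text. Below is a tiny sketch of that post-query filtering step; the match dictionaries are fabricated stand-ins for index.query output, so it runs without Pinecone.

```python
# Sketch of the post-query filtering step in data_indexing.py: keep only matches
# above a similarity threshold and return the "text" stored in their metadata.
SCORE_THRESHOLD = 0.540  # same cut-off the diff uses


def extract_passing_texts(matches, threshold=SCORE_THRESHOLD):
    docs = []
    for match in matches:
        print("result file:", match["metadata"].get("file_name"), "score:", match["score"])
        if match["score"] > threshold:
            docs.append(match["metadata"]["text"])
    return docs


if __name__ == "__main__":
    fake_matches = [
        {"score": 0.71, "metadata": {"file_name": "chains.py", "text": "def build_chain(): ..."}},
        {"score": 0.42, "metadata": {"file_name": "README.md", "text": "project overview"}},
    ]
    print(extract_passing_texts(fake_matches))  # only the 0.71 match survives
```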
main.py
CHANGED
@@ -9,11 +9,12 @@ from datetime import datetime

 import schemas
 from models import Message
-from chains import simple_chain, formatted_chain, history_chain, rag_chain
+from chains import simple_chain, formatted_chain, history_chain, rag_chain, filtered_rag_chain
 from prompts import format_chat_history
 import crud, models, schemas
 from database import SessionLocal, engine
 from callbacks import LogResponseCallback
+import json

 # temporary
 from database import engine
@@ -42,17 +43,32 @@ def get_db():
     # yield {'data': data, "event": "data"}
     # yield {"event": "end"}

-async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
-
-
-
-
-
-
-
-
-
+async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[], response_callback=None):
+    complete_response = ""
+    if callbacks is None:
+        callbacks = []
+    try:
+        stream_iterator = runnable.stream(input_data.dict(), config={"callbacks": callbacks})
+        for chunk in stream_iterator:
+            # ChatHuggingFace returns message chunks with a content attribute
+            if hasattr(chunk, 'content'):
+                content = chunk.content
+            else:
+                content = str(chunk)
+
+            complete_response += content
+            if content != "" or len(content) != 0:  # Only yield non-empty content
+                yield {'data': json.dumps({"content": content}), "event": "data"}
+            # yield {'data': content, "event": "data"}
+    except StopIteration:
+        print("stream ended with StopIteration")
+        yield {"event": "end"}
+    # except Exception as e:
+    #     print(f"error generating response: {e}")
+    if response_callback:
+        response_callback(complete_response)
     yield {"event": "end"}
+



@@ -73,7 +89,7 @@ async def formatted_stream(request: Request):
     output = EventSourceResponse(
         generate_stream(
             input_data = user_question,
-            runnable = formatted_chain
+            runnable = formatted_chain)
     )
     # print(output.generations[0][0].text)
     return output
@@ -99,6 +115,7 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
     # since history stream means
     # we have existing user's no need to check for a user
     chat_history = crud.get_user_chat_history(db, user_request.username)
+    print("chat_history from the database", chat_history)
     history_input = schemas.HistoryInput(
         chat_history = format_chat_history(chat_history),
         question=user_request.question
@@ -106,7 +123,7 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):

     ## adding message to message database
     type = 'Human'
-    user_data = crud.get_or_create_user(db, user_request.username)
+    user_data = crud.get_or_create_user(db, user_request.username)
     user_id = user_data.id
     timestamp = str(datetime.now())
     add_message = schemas.MessageBase(
@@ -120,9 +137,10 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
     _ = crud.add_message(db,add_message, username = user_request.username)
     # chat history contains: [{ message, type, timestamp}]

-
-
-
+    init = LogResponseCallback(user_request = user_request, db = db)
+    def save_full_response(complete_response):
+        init.on_llm_end(outputs=complete_response)
+    output = EventSourceResponse(generate_stream(history_input, history_chain, response_callback=save_full_response))
     return output
     # raise NotImplemented

@@ -151,11 +169,19 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
         user_id = user_id,
         message = user_request.question,
         type = type,
-        timestamp = timestamp
+        timestamp = timestamp,
+        user=user_request.username,
     )

     _ = crud.add_message(db,add_message, username = user_request.username)
-
+    print("/rag/stream: \n: successfully added message to database")
+    init = LogResponseCallback(user_request = user_request, db = db)
+    print("successfully initiated LogResponseCallback")
+    def save_full_response(complete_response):
+        init.on_llm_end(outputs=complete_response)
+
+    print("calling EventSourceResponse to generate stream............")
+    return EventSourceResponse(generate_stream(history_input, rag_chain, response_callback=save_full_response))
     # raise NotImplemented


@@ -169,19 +195,11 @@ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
     # - We create an instance of HistoryInput by using format_chat_history.
     # - We use the history input within the filtered rag chain.
     data = await request.json()
-    user_request =
-
-    messages = db.Query(
-        Message.message,
-        Message.type,
-        Message.timestamp
-    ).filter(Message.user_id == user_request.username)
+    user_request = schemas.UserRequest(**data['input'])
+    messages = crud.get_user_chat_history(db, user_request.username)
     chat_history = messages

-    history_input = schemas.HistoryInput(
-        chat_history = format_chat_history(chat_history),
-        question=user_request.question
-    )
+    history_input = schemas.HistoryInput( chat_history = format_chat_history(chat_history), question=user_request.question)
     ## adding message to message database
     type = 'Human'
     user_data = crud.get_or_create_user(db, user_request.username)
@@ -191,12 +209,20 @@ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
         user_id = user_id,
         message = user_request.question,
         type = type,
-        timestamp = timestamp
+        timestamp = timestamp,
+        user=user_request.username,
     )

     _ = crud.add_message(db,add_message, username = user_request.username)
+    print("/rag/stream: \n: successfully added message to database")
+    init = LogResponseCallback(user_request = user_request, db = db)
+    print("successfully initiated LogResponseCallback")
+    def save_full_response(complete_response):
+        init.on_llm_end(outputs=complete_response)
+
+    print("calling EventSourceResponse to generate stream............")

-    return EventSourceResponse(generate_stream(history_input, filtered_rag_chain))
+    return EventSourceResponse(generate_stream(history_input, filtered_rag_chain, response_callback=save_full_response))
     # raise NotImplemented


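Note: generate_stream above yields one SSE event per chunk, accumulates the full response, and hands it to response_callback once the stream ends. The sketch below reproduces that accumulate-then-callback pattern with a plain list of chunks in place of runnable.stream, so it runs without FastAPI or an LLM; all names here are illustrative.

```python
# Stand-alone sketch of the streaming pattern in main.py: yield SSE-style events
# per chunk, accumulate the full response, then hand it to a callback for logging.
import asyncio
import json


async def generate_stream(chunks, response_callback=None):
    complete_response = ""
    for content in chunks:
        complete_response += content
        if content:  # only yield non-empty content
            yield {"data": json.dumps({"content": content}), "event": "data"}
    if response_callback:
        # Called exactly once, with the full concatenated response.
        response_callback(complete_response)
    yield {"event": "end"}


async def main():
    fake_chunks = ["Hello", ", ", "world", "!"]
    async for event in generate_stream(fake_chunks, response_callback=lambda full: print("saved:", full)):
        print(event)


if __name__ == "__main__":
    asyncio.run(main())
```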
prompts.py
CHANGED
@@ -8,24 +8,24 @@ def format_prompt(prompt) -> PromptTemplate:
     template = f"""
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>
 You are a helpful assistant.<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+<|start_header_id|>user<|end_header_id|> Before answering tell me if you are given an empty context or not then answer
 {prompt}<|eot_id|>
 <|start_header_id|>assistant<|end_header_id|>
 """
-    raw_template = [
-
-
-    ]
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-    formatted_template = tokenizer.apply_chat_template(
-
-
-
-    )
+    # raw_template = [
+    #     {"role": "system", "content": "You are a helpful assistant."},
+    #     {"role": "user", "content": "{{prompt}}"},
+    # ]
+    # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+    # formatted_template = tokenizer.apply_chat_template(
+    #     raw_template,
+    #     tokenize=False,
+    #     add_generation_prompt=True
+    # )

     prompt_template = PromptTemplate.from_template(
         # input_variables=["question"], the variables will be auto detected by langchain package
-
+        template
     )
     # TODO: return a langchain PromptTemplate
     return prompt_template
@@ -64,6 +64,8 @@ raw_prompt = "{question}"
 history_prompt: str = """
 Given the following conversation provide a helpful answer to the follow up question.

+explain the previous questions if I ask,
+
 Chat History:

 {chat_history}
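Note: format_prompt above relies on nested brace handling: the outer f-string substitutes {prompt} immediately, while inner placeholders such as {question} survive for PromptTemplate to fill at invoke time. A minimal sketch of that mechanic follows; the langchain_core import path is an assumption consistent with the rest of this repo.

```python
# Sketch of the prompt wrapping done in prompts.py: the outer f-string injects the
# task prompt into Llama-3 chat markup, and PromptTemplate later fills the inner
# placeholders (e.g. {question}) when the chain is invoked.
from langchain_core.prompts import PromptTemplate


def format_prompt(prompt: str) -> PromptTemplate:
    template = f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{prompt}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
    # Input variables such as "question" are auto-detected from the braces left over.
    return PromptTemplate.from_template(template)


if __name__ == "__main__":
    raw_prompt_formatted = format_prompt("{question}")
    print(raw_prompt_formatted.format(question="What does LCEL stand for?"))
```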
sources.txt
ADDED
The diff for this file is too large to render. See raw diff.

test.db
CHANGED
Binary files a/test.db and b/test.db differ