sravan committed
Commit dc078e3 · Parent: ae692a1

first working application

Files changed (9):
  1. .gitignore +1 -1
  2. callbacks.py +6 -3
  3. chains.py +88 -8
  4. code_data/langchain_repo +1 -0
  5. data_indexing.py +132 -29
  6. main.py +57 -31
  7. prompts.py +14 -12
  8. sources.txt +0 -0
  9. test.db +0 -0
.gitignore CHANGED
@@ -1,4 +1,4 @@
 myenv
 *pycache*
 
-
+dang.py
callbacks.py CHANGED
@@ -2,6 +2,7 @@ from typing import Dict, Any, List
 from langchain_core.callbacks import BaseCallbackHandler
 import schemas
 import crud
+from datetime import datetime
 
 
 class LogResponseCallback(BaseCallbackHandler):
@@ -16,13 +17,15 @@ class LogResponseCallback(BaseCallbackHandler):
         # TODO: The function on_llm_end is going to be called when the LLM stops sending
         # the response. Use the crud.add_message function to capture that response.
         type = 'AI'
-        user_data = crud.get_or_create(self.db, self.user_request.username)
-        user_id = user_data.user_id
+        user_data = crud.get_or_create_user(self.db, self.user_request.username)
+        user_id = user_data.id
         timestamp = datetime.now()
-        message = outputs.generations[0][0].text  # answer from the prompt message
+        message = str(outputs)  # answer from the prompt message
+        print("history messages", message)
         message_to_add = schemas.MessageBase(
             user_id = user_id,
             message = message,
+            user = user_data.username,
             type = type,
             timestamp = timestamp
         )
chains.py CHANGED
@@ -1,6 +1,7 @@
 import os
 from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
 from langchain_core.runnables import RunnablePassthrough
+from langchain.schema.runnable import RunnableLambda
 
 import schemas
 import prompts
@@ -12,7 +13,7 @@ from prompts import (
     standalone_prompt_formatted,
     rag_prompt_formatted
 )
-from data_indexing import DataIndexer
+from data_indexing import DataIndexer
 from transformers import AutoTokenizer
 
 data_indexer = DataIndexer()
@@ -52,37 +53,116 @@ llm_endpoint = HuggingFaceEndpoint(
 
 llm = ChatHuggingFace(llm=llm_endpoint)
 
+def print_and_pass(prompt_output):
+    print("=" * 60)
+    print("🔍 RAW PROMPT FORMATTED:")
+    print("=" * 60)
+    print(prompt_output)
+    print("=" * 60)
+    return prompt_output  # IMPORTANT: must return the prompt unchanged
+
 simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
 
+
+
 # TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
-formatted_chain = (raw_prompt_formatted | llm).with_types(input_type=schemas.UserQuestion)
+formatted_chain = (raw_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.UserQuestion)
 
 # TODO: use history_prompt_formatted and HistoryInput to create the history_chain
-history_chain = (history_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
+history_chain = (history_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.HistoryInput)
 
 # TODO: Let's construct the standalone_chain by piping standalone_prompt_formatted with the LLM
 standalone_chain = (standalone_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
 
+# summarize_chain = (summarize_prompt_formatted | llm)
+
+import ast
+
+def extract_definitions(source_code):
+    """
+    Extract top-level function and class definitions from Python code.
+    """
+    result = []
+    try:
+        tree = ast.parse(source_code)
+        for node in ast.iter_child_nodes(tree):
+            if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
+                snippet = ast.get_source_segment(source_code, node)
+                if snippet:
+                    result.append(snippet)
+    except Exception as e:
+        print(f"Failed to parse code: {e}")
+    return result
+
+import re
+
+def clean_code_text(code_text):
+    """
+    Remove comments and excessive blank lines for brevity.
+    """
+    # Remove multiline docstrings and comments
+    code_text = re.sub(r'"""(.*?)"""', '', code_text, flags=re.DOTALL)
+    code_text = re.sub(r"'''(.*?)'''", '', code_text, flags=re.DOTALL)
+
+    # Remove inline comments
+    code_text = re.sub(r'#.*', '', code_text)
+
+    # Remove excessive whitespace
+    code_text = re.sub(r'\n\s*\n+', '\n\n', code_text)
+
+    return code_text.strip()
+
+
+def safe_format_context(search_results):
+    try:
+        cleaned_results = []
+        for result in search_results:
+            if isinstance(result, str):
+                # Optionally: extract relevant functions/classes
+                code_parts = extract_definitions(result)
+                for part in code_parts:
+                    cleaned = clean_code_text(part)
+                    cleaned_results.append(cleaned)
+        return format_context(cleaned_results)
+    except Exception as e:
+        print(f"Error formatting context: {str(e)}")
+        return "No relevant context found."
+
+
 input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
+
+# input_1_beta = RunnablePassthrough.assign(new_context=summarize_chain)
+
+def extract_question_text(new_question):
+    if hasattr(new_question, "content"):
+        return new_question.content
+    return str(new_question)
+
+# summarize_context = {
+#     'context': lambda x: safe_format_context(data_indexer.search(extract_question_text(x['new_question']))),
+#     'standalone_question': lambda x: extract_question_text(x['new_question']),
+# }
+
 input_2 = {
-    'context': lambda x: format_context(data_indexer.search(x['new_question'])),
-    'standalone_question': lambda x: x['new_question']  # new question was the parameter in input_1
+    'context': lambda x: safe_format_context(data_indexer.search(extract_question_text(x['new_question']))),
+    'standalone_question': lambda x: extract_question_text(x['new_question']),
 }
+
 input_to_rag_chain = input_1 | input_2
 
 # TODO: use input_to_rag_chain, rag_prompt_formatted,
 # HistoryInput and the LLM to build the rag_chain.
-rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.RagInput)
+rag_chain = (input_to_rag_chain | RunnableLambda(print_and_pass) | rag_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.RagInput)
 
 # TODO: Implement the filtered_rag_chain. It should be the
 # same as the rag_chain but with hybrid_search = True.
 
 input_2_hybrid_search = {
-    'context': lambda x: format_context(data_indexer.search(x['new_question'], hybrid_search=True)),
+    'context': lambda x: safe_format_context(data_indexer.search(extract_question_text(x['new_question']), hybrid_search=True)),
     'standalone_question': lambda x: x['new_question']
 }
 
-filtered_rag_chain = (input_1 | input_2_hybrid_search | rag_prompt_formatted | llm).with_types(input_type=schemas.RagInput)
+filtered_rag_chain = (input_1 | input_2_hybrid_search | rag_prompt_formatted | RunnableLambda(print_and_pass) | llm).with_types(input_type=schemas.RagInput)
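
Aside: the print_and_pass helper spliced into the chains above is a general LCEL debugging trick: wrap a plain function in RunnableLambda and pipe it between two steps to inspect the intermediate value without changing it. A minimal, self-contained sketch of the same pattern; the prompt text and the final uppercasing step are illustrative stand-ins, not code from this repo:

from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda

def print_and_pass(value):
    # Print the intermediate value, then hand it to the next step unchanged.
    print("intermediate value:", value)
    return value

prompt = PromptTemplate.from_template("Answer the question: {question}")
chain = prompt | RunnableLambda(print_and_pass) | RunnableLambda(lambda p: p.to_string().upper())
print(chain.invoke({"question": "what is LCEL?"}))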
code_data/langchain_repo ADDED
@@ -0,0 +1 @@
+Subproject commit 2d0713c2fc5a457578635b03b7a00e970ce534ee
data_indexing.py CHANGED
@@ -6,9 +6,41 @@ from pinecone import ServerlessSpec
 from langchain_community.vectorstores import Chroma
 from langchain_openai import OpenAIEmbeddings
 from huggingface_hub import InferenceClient
+from typing import List
+from datetime import datetime
+from sentence_transformers import SentenceTransformer
+from langchain.embeddings.base import Embeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
+import json
+
 
 current_dir = Path(__file__).resolve().parent
 
+class SentenceTransfmEmbeddings(Embeddings):
+    """Sentence Transformers embedding class"""
+
+    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2"):
+        self.model = SentenceTransformer(model_name)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents"""
+        try:
+            embeddings = self.model.encode(texts)
+            return embeddings.tolist()
+        except Exception as e:
+            print(f"Error embedding documents: {e}")
+            # Return dummy embeddings to prevent a crash
+            return [[0.0] * 768 for _ in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a single query"""
+        try:
+            embedding = self.model.encode([text])
+            return embedding[0].tolist()
+        except Exception as e:
+            print(f"Error embedding query: {e}")
+            return [0.0] * 768
+
 
 class DataIndexer:
 
@@ -18,9 +50,16 @@
 
         # TODO: choose your embedding model
         self.embedding_client = InferenceClient(
-            "dunzhang/stella_en_1.5B_v5",
+            # "dunzhang/stella_en_1.5B_v5",
+            "sentence-transformers/all-mpnet-base-v2",
             token=os.environ['HF_TOKEN'],
         )
+        self.embeddings = SentenceTransfmEmbeddings(
+            "sentence-transformers/all-mpnet-base-v2"
+        )
+        # self.embeddings = HuggingFaceEmbeddings(
+        #     model_name="sentence-transformers/all-mpnet-base-v2"
+        # )
         self.spec = ServerlessSpec(
             cloud = 'aws',
             region='us-east-1'
@@ -34,14 +73,22 @@
             # Make sure to choose the dimension that corresponds to your embedding model
             self.pinecone_client.create_index(
                 name=index_name,
-                dimension=1024,
+                dimension=768,
                 metric='cosine',
                 spec=self.spec
             )
 
         self.index = self.pinecone_client.Index(self.index_name)
         # TODO: make sure to build the index.
-        self.source_index = self.get_source_index()
+        # with open(self.source_file, 'r') as file:
+        #     sources = file.readlines()
+
+        # sources = [s.strip() for s in sources if s.strip()]
+        # if not sources:
+        #     self.source_index = None
+        # else:
+        #     self.source_index = self.get_source_index()
+        self.source_index = None
 
     def get_source_index(self):
         if not os.path.isfile(self.source_file):
@@ -53,9 +100,17 @@
         with open(self.source_file, 'r') as file:
             sources = file.readlines()
 
-        sources = [s.rstrip('\n') for s in sources]
+        sources = [s.strip() for s in sources if s.strip()]
+        if not sources:
+            print("No valid sources to index")
+            return None
+        print("sources are:", sources)
+        ## testing
+        embeddings = self.embeddings.embed_documents(sources)
+        print(f"Generated {len(embeddings)} embeddings for {len(sources)} sources")
+        ## testing
         vectorstore = Chroma.from_texts(
-            sources, embedding=self.embedding_client
+            sources, embedding=self.embeddings
         )
         return vectorstore
 
@@ -64,19 +119,21 @@
         with open(self.source_file, 'a') as file:
             for doc in docs:
                 file.writelines(doc.metadata['source'] + '\n')
+
+        self.source_index = self.get_source_index()
 
         for i in range(0, len(docs), batch_size):
            batch = docs[i: i + batch_size]
 
            # TODO: create a list of the vector representations of each text data in the batch
            # TODO: choose your embedding model
-           values = self.embedding_client.embed_documents([
-               doc.page_content for doc in batch
-           ])
-
-           # values = self.embedding_client.feature_extraction([
+           # values = self.embedding_client.embed_documents([
            #     doc.page_content for doc in batch
            # ])
+
+           values = self.embedding_client.feature_extraction([
+               doc.page_content for doc in batch
+           ])
           # values = None
 
           # TODO: create a list of unique identifiers for each element in the batch with the uuid package.
@@ -85,7 +142,7 @@
           # TODO: create a list of dictionaries representing the metadata. Capture the text data
           # with the "text" key, and make sure to capture the rest of the doc.metadata.
           metadatas = [{"text": doc.page_content,
-                        **doc.metadata
+                        **(doc.metadata if doc.metadata else {})
                        } for doc in batch]
 
          # create a list of dictionaries with keys "id" (the unique identifiers), "values"
@@ -96,6 +153,8 @@
              'metadata': metadata
          } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]
 
+         for v in vectors[:5]:
+             print("Metadata:", v['metadata'])
          try:
              # TODO: Use the function upsert to upload the data to the database.
              upsert_response = self.index.upsert(vectors)
@@ -111,28 +170,48 @@
         # to the question. Make sure to adjust this number as you see fit.
         source_docs = self.source_index.similarity_search(text_query, 50)
         filter = {"source": {"$in": [doc.page_content for doc in source_docs]}}
-
+        result = ""
         # TODO: embed the text_query by using the embedding model
         # TODO: choose your embedding model
         # vector = self.embedding_client.feature_extraction(text_query)
-        vector = self.embedding_client.embed_query(text_query)
-        # vector = None
+        try:
+            print("text")
+            print(text_query)
+            vector = self.embedding_client.feature_extraction(
+                text=text_query,
+            )
+            if vector is None:
+                print("failed to embed the text query for the Pinecone vector search")
+                return []
+            else:
+                print("debug1_result")
+                result = self.index.query(vector,
+                                          filter=filter,
+                                          top_k=top_k,
+                                          include_values=True,
+                                          include_metadata=True
+                                          )
+                print(f"query successful without error for the question: {text_query}")
+
+                docs = []
+                # print(f"none type in result? {result}")
+                for res in result["matches"]:
+                    # TODO: From the result's metadata, extract the "text" element.
+                    print("result file name:", res['metadata']['file_name'])
+                    print("result score:", res['score'])
+                    if res['score'] > 0.540:
+                        docs.append(res['metadata']['text'])
+                    # pass
+                # print("docs: ", docs[0])
+
+                return docs
+        except Exception as e:
+            print(f"error in search: {e}")
+            return []
 
         # TODO: use the vector representation of the text_query to
         # search the database by using the query function.
-        result = self.index.query(vector,
-                                  filter=filter,
-                                  top_k=top_k,
-                                  include_values=True
-                                  )
-
-        docs = []
-        for res in result["matches"]:
-            # TODO: From the result's metadata, extract the "text" element.
-            docs.append(res['metadata']['text'])
-            # pass
-
-        return docs
+
 
 
 if __name__ == '__main__':
@@ -142,6 +221,7 @@ if __name__ == '__main__':
         Language,
         RecursiveCharacterTextSplitter,
     )
+    print("start:", datetime.now())
 
     loader = GitLoader(
         clone_url="https://github.com/langchain-ai/langchain",
@@ -159,9 +239,32 @@
     docs = python_splitter.split_documents(docs)
     for doc in docs:
         doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content
-
+    print("before instantiating the indexer:", datetime.now())
     indexer = DataIndexer()
-    with open('/app/sources.txt', 'a') as file:
+    print("after instantiating the indexer:", datetime.now())
+    with open('./app/sources.txt', 'a') as file:
         for doc in docs:
             file.writelines(doc.metadata['source'] + '\n')
+    print("after writing the sources file:", datetime.now())
     indexer.index_data(docs)
+    print("end:", datetime.now())
+
+    # ###### test ###########
+    # test_docs = docs[:2]  # Just try the first two documents
+    # print("\nTest Document Details:")
+    # print(f"Number of test documents: {len(test_docs)}")
+    # for idx, doc in enumerate(test_docs):
+    #     print(f"\nDocument {idx + 1}:")
+    #     print(f"Content length: {len(doc.page_content)}")
+    #     # print(f"First 100 chars: {doc.page_content[:100]}")
+    #     print(f"Metadata: {doc.metadata}")
+
+    # # try:
+    # print("\nInitializing DataIndexer...")
+    # indexer = DataIndexer()
+    # print("\nStarting indexing...")
+    # indexer.index_data(test_docs)
+    # print("Test indexing successful")
+    # # except Exception as e:
+    # #     print(f"Test indexing failed: {str(e)}")
+
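
Note on the dimension change above: the index dimension moves from 1024 to 768 because sentence-transformers/all-mpnet-base-v2 produces 768-dimensional vectors. A quick sanity-check sketch of that assumption (the sample texts are arbitrary):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
vectors = model.encode(["def add(a, b): return a + b", "class Foo: pass"])
# The second dimension must match the Pinecone index dimension (768).
print(vectors.shape)  # expected: (2, 768)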
main.py CHANGED
@@ -9,11 +9,12 @@ from datetime import datetime
 
 import schemas
 from models import Message
-from chains import simple_chain, formatted_chain, history_chain, rag_chain
+from chains import simple_chain, formatted_chain, history_chain, rag_chain, filtered_rag_chain
 from prompts import format_chat_history
 import crud, models, schemas
 from database import SessionLocal, engine
 from callbacks import LogResponseCallback
+import json
 
 # temporary
 from database import engine
@@ -42,17 +43,32 @@ def get_db():
 #     yield {'data': data, "event": "data"}
 #     yield {"event": "end"}
 
-async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
-    for chunk in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
-        # ChatHuggingFace returns message chunks with content attribute
-        if hasattr(chunk, 'content'):
-            content = chunk.content
-        else:
-            content = str(chunk)
-
-        if content:  # Only yield non-empty content
-            yield {'data': content, "event": "data"}
+async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[], response_callback=None):
+    complete_response = ""
+    if callbacks is None:
+        callbacks = []
+    try:
+        stream_iterator = runnable.stream(input_data.dict(), config={"callbacks": callbacks})
+        for chunk in stream_iterator:
+            # ChatHuggingFace returns message chunks with a content attribute
+            if hasattr(chunk, 'content'):
+                content = chunk.content
+            else:
+                content = str(chunk)
+
+            complete_response += content
+            if content != "":  # Only yield non-empty content
+                yield {'data': json.dumps({"content": content}), "event": "data"}
+                # yield {'data': content, "event": "data"}
+    except StopIteration:
+        print("stream ended with StopIteration")
+        yield {"event": "end"}
+    # except Exception as e:
+    #     print(f"error generating response: {e}")
+    if response_callback:
+        response_callback(complete_response)
     yield {"event": "end"}
+
 
 
 
@@ -73,7 +89,7 @@ async def formatted_stream(request: Request):
     output = EventSourceResponse(
         generate_stream(
             input_data = user_question,
-            runnable = formatted_chain )
+            runnable = formatted_chain)
     )
     # print(output.generations[0][0].text)
     return output
@@ -99,6 +115,7 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
     # since history stream means
     # we have existing users, no need to check for a user
     chat_history = crud.get_user_chat_history(db, user_request.username)
+    print("chat_history from the database", chat_history)
     history_input = schemas.HistoryInput(
         chat_history = format_chat_history(chat_history),
         question=user_request.question
@@ -106,7 +123,7 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
 
     ## adding message to the message database
     type = 'Human'
-    user_data = crud.get_or_create_user(db, user_request.username)
+    user_data = crud.get_or_create_user(db, user_request.username)
     user_id = user_data.id
     timestamp = str(datetime.now())
     add_message = schemas.MessageBase(
@@ -120,9 +137,10 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
     _ = crud.add_message(db, add_message, username = user_request.username)
     # chat history contains: [{ message, type, timestamp }]
 
-    output = EventSourceResponse(generate_stream(history_input, history_chain))
-    LogResponseCallback.on_llm_end(outputs = output)
-
+    init = LogResponseCallback(user_request = user_request, db = db)
+    def save_full_response(complete_response):
+        init.on_llm_end(outputs=complete_response)
+    output = EventSourceResponse(generate_stream(history_input, history_chain, response_callback=save_full_response))
     return output
     # raise NotImplemented
 
@@ -151,11 +169,19 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
         user_id = user_id,
         message = user_request.question,
         type = type,
-        timestamp = timestamp
+        timestamp = timestamp,
+        user=user_request.username,
     )
 
     _ = crud.add_message(db, add_message, username = user_request.username)
-    return EventSourceResponse(generate_stream(history_input, rag_chain))
+    print("/rag/stream: successfully added message to database")
+    init = LogResponseCallback(user_request = user_request, db = db)
+    print("successfully initiated LogResponseCallback")
+    def save_full_response(complete_response):
+        init.on_llm_end(outputs=complete_response)
+
+    print("calling EventSourceResponse to generate stream............")
+    return EventSourceResponse(generate_stream(history_input, rag_chain, response_callback=save_full_response))
     # raise NotImplemented
 
 
@@ -169,19 +195,11 @@ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
     # - We create an instance of HistoryInput by using format_chat_history.
     # - We use the history input within the filtered rag chain.
     data = await request.json()
-    user_request = models.UserRequest(**dat['input'])
-
-    messages = db.Query(
-        Message.message,
-        Message.type,
-        Message.timestamp
-    ).filter(Message.user_id == user_request.username)
+    user_request = schemas.UserRequest(**data['input'])
+    messages = crud.get_user_chat_history(db, user_request.username)
     chat_history = messages
 
-    history_input = schemas.HistoryInput(
-        chat_history = format_chat_history(chat_history),
-        question=user_request.question
-    )
+    history_input = schemas.HistoryInput(chat_history = format_chat_history(chat_history), question=user_request.question)
     ## adding message to the message database
     type = 'Human'
     user_data = crud.get_or_create_user(db, user_request.username)
@@ -191,12 +209,20 @@ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
         user_id = user_id,
         message = user_request.question,
         type = type,
-        timestamp = timestamp
+        timestamp = timestamp,
+        user=user_request.username,
     )
 
     _ = crud.add_message(db, add_message, username = user_request.username)
+    print("/filtered_rag/stream: successfully added message to database")
+    init = LogResponseCallback(user_request = user_request, db = db)
+    print("successfully initiated LogResponseCallback")
+    def save_full_response(complete_response):
+        init.on_llm_end(outputs=complete_response)
+
+    print("calling EventSourceResponse to generate stream............")
 
-    return EventSourceResponse(generate_stream(history_input, filtered_rag_chain))
+    return EventSourceResponse(generate_stream(history_input, filtered_rag_chain, response_callback=save_full_response))
    # raise NotImplemented
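
Since generate_stream now wraps each chunk as JSON ({"content": ...}) in the SSE data field and finishes with an end event, a client has to JSON-decode every data payload. A hedged client-side sketch; the URL, the POST method, and the exact request body shape are assumptions inferred from the handlers touched here, not part of this commit:

import json
import requests

with requests.post(
    "http://localhost:8000/rag/stream",  # assumed route for rag_stream
    json={"input": {"username": "alice", "question": "What is a runnable?"}},
    stream=True,
) as resp:
    for raw_line in resp.iter_lines(decode_unicode=True):
        # SSE data lines look like: data: {"content": "..."}
        if raw_line and raw_line.startswith("data:"):
            payload = raw_line[len("data:"):].strip()
            try:
                print(json.loads(payload)["content"], end="", flush=True)
            except json.JSONDecodeError:
                pass  # e.g. the empty data field of the final "end" event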
prompts.py CHANGED
@@ -8,24 +8,24 @@ def format_prompt(prompt) -> PromptTemplate:
     template = f"""
     <|begin_of_text|><|start_header_id|>system<|end_header_id|>
     You are a helpful assistant.<|eot_id|>
-    <|start_header_id|>user<|end_header_id|>
+    <|start_header_id|>user<|end_header_id|> Before answering, tell me whether you were given an empty context or not, then answer.
     {prompt}<|eot_id|>
     <|start_header_id|>assistant<|end_header_id|>
     """
-    raw_template = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "{{prompt}}"},
-    ]
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-    formatted_template = tokenizer.apply_chat_template(
-        raw_template,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    # raw_template = [
+    #     {"role": "system", "content": "You are a helpful assistant."},
+    #     {"role": "user", "content": "{{prompt}}"},
+    # ]
+    # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+    # formatted_template = tokenizer.apply_chat_template(
+    #     raw_template,
+    #     tokenize=False,
+    #     add_generation_prompt=True
+    # )
 
     prompt_template = PromptTemplate.from_template(
         # input_variables=["question"], the variables will be auto-detected by the langchain package
-        formatted_template
+        template
     )
     # TODO: return a langchain PromptTemplate
     return prompt_template
@@ -64,6 +64,8 @@ raw_prompt = "{question}"
 history_prompt: str = """
 Given the following conversation provide a helpful answer to the follow up question.
 
+explain the previous questions if I ask,
+
 Chat History:
 
 {chat_history}
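
The change in format_prompt relies on PromptTemplate.from_template detecting placeholders in the template string by itself, so input_variables never has to be passed explicitly. A small illustration of that behavior (the template text here is just an example):

from langchain_core.prompts import PromptTemplate

template = PromptTemplate.from_template("Answer the question: {question}")
print(template.input_variables)                  # ['question']
print(template.format(question="What is RAG?"))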
sources.txt ADDED
The diff for this file is too large to render.
 
test.db CHANGED
Binary files a/test.db and b/test.db differ