nciso committed on
Commit
bd346cd
·
verified ·
1 Parent(s): db2c89b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +52 -46
app.py CHANGED
@@ -1,4 +1,5 @@
1
 
 
2
  # Import necessary libraries
3
  import os # Interacting with the operating system (reading/writing files)
4
  import chromadb # High-performance vector database for storing/querying dense vectors
@@ -61,7 +62,7 @@ MEM0_api_key = os.getenv("MEM0_API_KEY")
61
  embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
62
  api_base=endpoint, # Complete the code to define the API base endpoint
63
  api_key=api_key, # Complete the code to define the API key
64
- model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
65
  )
66
 
67
  # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
@@ -70,7 +71,7 @@ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
70
  embedding_model = OpenAIEmbeddings(
71
  openai_api_base=endpoint,
72
  openai_api_key=api_key,
73
- model='text-embedding-ada-002'
74
  )
75
 
76
 
@@ -111,48 +112,51 @@ def expand_query(state):
111
  Dict: The updated state with the expanded query.
112
  """
113
  print("---------Expanding Query---------")
114
- system_message = '''You are a query-expansion engine for a medical retrieval system.
 
115
 
116
- Your job:
117
- 1. Expand the user's query into 6–8 alternative questions that could retrieve the same medical information.
118
 
119
- Rules:
120
- - Do NOT answer the query.
121
- - Keep output in the same language as input.
122
- - Preserve key entities (e.g., vitamins, disorders, nutrients).
123
- - Each query must be ≤ 16 words.
124
- - Output strict JSON only. No explanation. No extra text.
125
 
126
- Schema:
127
- {{
128
- "queries": ["...", "..."]
129
- }}'''
 
130
 
131
  expand_prompt = ChatPromptTemplate.from_messages([
132
  ("system", system_message),
133
  ("user", "Expand this query: {query} using the feedback: {query_feedback}")
134
-
135
  ])
136
 
137
  chain = expand_prompt | llm | StrOutputParser()
138
- expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]})
139
  print("expanded_query", expanded_query)
140
  state["expanded_query"] = expanded_query
141
  return state
142
 
143
 
 
144
  # Initialize the Chroma vector store for retrieving documents
145
  vector_store = Chroma(
146
- collection_name="nutritional_hypotheticals",
147
- persist_directory="./nutritional_db",
148
- embedding_function=embedding_model
149
-
150
  )
151
 
152
  # Create a retriever from the vector store
 
 
153
  retriever = vector_store.as_retriever(
154
- search_type='similarity',
155
- search_kwargs={'k': 3}
156
  )
157
 
158
  def retrieve_context(state):
@@ -166,18 +170,21 @@ def retrieve_context(state):
166
  Dict: The updated state with the retrieved context.
167
  """
168
  print("---------retrieve_context---------")
169
- query = state['expanded_query'] # Complete the code to define the key for the expanded query
 
170
  #print("Query used for retrieval:", query) # Debugging: Print the query
171
 
172
- # Retrieve documents from the vector store
173
  docs = retriever.invoke(query)
174
  print("Retrieved documents:", docs) # Debugging: Print the raw docs object
175
 
 
 
176
  # Extract both page_content and metadata from each document
177
  context= [
178
- {
179
- "content": doc.page_content, # The actual content of the document
180
- "metadata": doc.metadata # The metadata (e.g., source, page number, etc.)
181
  }
182
  for doc in docs
183
  ]
@@ -199,7 +206,8 @@ def craft_response(state: Dict) -> Dict:
199
  Dict: The updated state with the generated response.
200
  """
201
  print("---------craft_response---------")
202
- system_message = '''You are a medical assistant specializing in nutritional disorders.
 
203
 
204
  Your job is to generate concise, accurate answers strictly based on the provided context from a textbook or trusted source.
205
 
@@ -218,20 +226,22 @@ def craft_response(state: Dict) -> Dict:
218
  ("system", system_message),
219
  ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
220
  ])
 
 
221
 
 
222
  chain = response_prompt | llm
223
  response = chain.invoke({
224
  "query": state['query'],
225
- "context": "\n".join([doc["content"] for doc in state['context']]),
226
- "feedback": state.get('feedback', '') # add feedback to the prompt
227
  })
228
- state['response'] = response.content
229
- print("intermediate response: ", response.content)
230
 
231
  return state
232
 
233
 
234
-
235
  def score_groundedness(state: Dict) -> Dict:
236
  """
237
  Checks whether the response is grounded in the retrieved context.
@@ -265,7 +275,7 @@ def score_groundedness(state: Dict) -> Dict:
265
  chain = groundedness_prompt | llm | StrOutputParser()
266
  groundedness_score = float(chain.invoke({
267
  "context": "\n".join([doc["content"] for doc in state['context']]),
268
- "response": state['response'] # Complete the code to define the response
269
  }))
270
  print("groundedness_score: ", groundedness_score)
271
  state['groundedness_loop_count'] += 1
@@ -275,10 +285,9 @@ def score_groundedness(state: Dict) -> Dict:
275
  return state
276
 
277
 
278
-
279
  def check_precision(state: Dict) -> Dict:
280
  """
281
- Checks whether the response precisely addresses the user's query.
282
 
283
  Args:
284
  state (Dict): The current state of the workflow, containing the query and response.
@@ -308,7 +317,7 @@ def check_precision(state: Dict) -> Dict:
308
  chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
309
  precision_score = float(chain.invoke({
310
  "query": state['query'],
311
- "response": state['response'] # Complete the code to access the response from the state
312
  }))
313
  state['precision_score'] = precision_score
314
  print("precision_score:", precision_score)
@@ -317,7 +326,6 @@ def check_precision(state: Dict) -> Dict:
317
  return state
318
 
319
 
320
-
321
  def refine_response(state: Dict) -> Dict:
322
  """
323
  Suggests improvements for the generated response.
@@ -357,7 +365,6 @@ def refine_response(state: Dict) -> Dict:
357
  return state
358
 
359
 
360
-
361
  def refine_query(state: Dict) -> Dict:
362
  """
363
  Suggests improvements for the expanded query.
@@ -401,25 +408,24 @@ def should_continue_groundedness(state):
401
  """Decides if groundedness is sufficient or needs improvement."""
402
  print("---------should_continue_groundedness---------")
403
  print("groundedness loop count: ", state['groundedness_loop_count'])
404
- if state['groundedness_score'] >= 7.0: # Complete the code to define the threshold for groundedness
405
  print("Moving to precision")
406
  return "check_precision"
407
  else:
408
- if state["groundedness_loop_count"] > state['loop_max_iter']:
409
  return "max_iterations_reached"
410
  else:
411
  print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
412
  return "refine_response"
413
 
414
-
415
  def should_continue_precision(state: Dict) -> str:
416
  """Decides if precision is sufficient or needs improvement."""
417
  print("---------should_continue_precision---------")
418
- print("precision loop count: ", state['precision_loop_count'])
419
- if state['precision_score'] >= 7.0: # Threshold for precision
420
  return "pass" # Complete the workflow
421
  else:
422
- if state['precision_loop_count'] > state['loop_max_iter']: # Maximum allowed loops
423
  return "max_iterations_reached"
424
  else:
425
  print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
 
1
 
2
+
3
  # Import necessary libraries
4
  import os # Interacting with the operating system (reading/writing files)
5
  import chromadb # High-performance vector database for storing/querying dense vectors
 
62
  embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
63
  api_base=endpoint, # Complete the code to define the API base endpoint
64
  api_key=api_key, # Complete the code to define the API key
65
+ model_name='text-embedding-3-small' # This is a fixed value and does not need modification
66
  )
67
 
68
  # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
 
71
  embedding_model = OpenAIEmbeddings(
72
  openai_api_base=endpoint,
73
  openai_api_key=api_key,
74
+ model='text-embedding-3-small'
75
  )
76
 
77
 
 
112
  Dict: The updated state with the expanded query.
113
  """
114
  print("---------Expanding Query---------")
115
+ system_message = """
116
+ You are a query-expansion engine for a medical retrieval system.
117
 
118
+ Your job:
119
+ 1. Expand the user's query into 6–8 alternative questions that could retrieve the same medical information.
120
 
121
+ Rules:
122
+ - Do NOT answer the query.
123
+ - Keep output in the same language as input.
124
+ - Preserve key entities (e.g., vitamins, disorders, nutrients).
125
+ - Each query must be ≤ 16 words.
126
+ - Output strict JSON only. No explanation. No extra text.
127
 
128
+ Schema:
129
+ {{
130
+ "queries": ["...", "..."]
131
+ }}
132
+ """
133
 
134
  expand_prompt = ChatPromptTemplate.from_messages([
135
  ("system", system_message),
136
  ("user", "Expand this query: {query} using the feedback: {query_feedback}")
 
137
  ])
138
 
139
  chain = expand_prompt | llm | StrOutputParser()
140
+ expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
141
  print("expanded_query", expanded_query)
142
  state["expanded_query"] = expanded_query
143
  return state
144
 
145
 
146
+
147
  # Initialize the Chroma vector store for retrieving documents
148
  vector_store = Chroma(
149
+ collection_name='nutritional_hypotheticals', # Complete the code to define the collection name
150
+ persist_directory='./nutritional_db', # Complete the code to define the directory for persistence
151
+ embedding_function=embedding_model # Complete the code to define the embedding function
 
152
  )
153
 
154
  # Create a retriever from the vector store
155
+
156
+ # this is the provided code but I want to use the structured retriever
157
  retriever = vector_store.as_retriever(
158
+ search_type='similarity', # Complete the code to define the search type
159
+ search_kwargs={'k': 6} # Complete the code to define the number of results to retrieve
160
  )
161
 
162
  def retrieve_context(state):
 
170
  Dict: The updated state with the retrieved context.
171
  """
172
  print("---------retrieve_context---------")
173
+ query = state.get('expanded_query') or state.get('query') # Complete the code to define the key for the expanded query
174
+
175
  #print("Query used for retrieval:", query) # Debugging: Print the query
176
 
177
+ # Retrieve hypothetical questions from the vector store
178
  docs = retriever.invoke(query)
179
  print("Retrieved documents:", docs) # Debugging: Print the raw docs object
180
 
181
+
182
+
183
  # Extract both page_content and metadata from each document
184
  context= [
185
+ {
186
+ "content": doc.metadata.get("original_content", ""),
187
+ "metadata": doc.metadata
188
  }
189
  for doc in docs
190
  ]
 
206
  Dict: The updated state with the generated response.
207
  """
208
  print("---------craft_response---------")
209
+ system_message = '''
210
+ You are a medical assistant specializing in nutritional disorders.
211
 
212
  Your job is to generate concise, accurate answers strictly based on the provided context from a textbook or trusted source.
213
 
 
226
  ("system", system_message),
227
  ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
228
  ])
229
+ context_docs = state.get("context", [])
230
+ context_string = "\n\n".join(doc["metadata"].get("original_content", "") for doc in context_docs)
231
 
232
+ feedback_text = state.get("feedback", "None")
233
  chain = response_prompt | llm
234
  response = chain.invoke({
235
  "query": state['query'],
236
+ "context": context_string,
237
+ "feedback": feedback_text # add feedback to the prompt
238
  })
239
+ state['response'] = response
240
+ print("intermediate response: ", response)
241
 
242
  return state
243
 
244
 
 
245
  def score_groundedness(state: Dict) -> Dict:
246
  """
247
  Checks whether the response is grounded in the retrieved context.
 
275
  chain = groundedness_prompt | llm | StrOutputParser()
276
  groundedness_score = float(chain.invoke({
277
  "context": "\n".join([doc["content"] for doc in state['context']]),
278
+ "response": state["response"] # Complete the code to define the response
279
  }))
280
  print("groundedness_score: ", groundedness_score)
281
  state['groundedness_loop_count'] += 1
 
285
  return state
286
 
287
 
 
288
  def check_precision(state: Dict) -> Dict:
289
  """
290
+ Checks whether the response precisely addresses the users query.
291
 
292
  Args:
293
  state (Dict): The current state of the workflow, containing the query and response.
 
317
  chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
318
  precision_score = float(chain.invoke({
319
  "query": state['query'],
320
+ "response":state['response'] # Complete the code to access the response from the state
321
  }))
322
  state['precision_score'] = precision_score
323
  print("precision_score:", precision_score)
 
326
  return state
327
 
328
 
 
329
  def refine_response(state: Dict) -> Dict:
330
  """
331
  Suggests improvements for the generated response.
 
365
  return state
366
 
367
 
 
368
  def refine_query(state: Dict) -> Dict:
369
  """
370
  Suggests improvements for the expanded query.
 
408
  """Decides if groundedness is sufficient or needs improvement."""
409
  print("---------should_continue_groundedness---------")
410
  print("groundedness loop count: ", state['groundedness_loop_count'])
411
+ if state['groundedness_score'] >= 0.85: # Complete the code to define the threshold for groundedness
412
  print("Moving to precision")
413
  return "check_precision"
414
  else:
415
+ if state["groundedness_loop_count"] >= state['loop_max_iter']:
416
  return "max_iterations_reached"
417
  else:
418
  print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
419
  return "refine_response"
420
 
 
421
  def should_continue_precision(state: Dict) -> str:
422
  """Decides if precision is sufficient or needs improvement."""
423
  print("---------should_continue_precision---------")
424
+ print("precision loop count: ", state["precision_loop_count"])
425
+ if state["precision_score"] >= 0.85: # Threshold for precision
426
  return "pass" # Complete the workflow
427
  else:
428
+ if state["precision_loop_count"] >= state["loop_max_iter"]: # Maximum allowed loops
429
  return "max_iterations_reached"
430
  else:
431
  print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging