Upload app.py with huggingface_hub
app.py CHANGED
@@ -1,4 +1,5 @@
 
+
 # Import necessary libraries
 import os  # Interacting with the operating system (reading/writing files)
 import chromadb  # High-performance vector database for storing/querying dense vectors
@@ -61,7 +62,7 @@ MEM0_api_key = os.getenv("MEM0_API_KEY")
 embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
     api_base=endpoint,  # Complete the code to define the API base endpoint
     api_key=api_key,  # Complete the code to define the API key
-    model_name='text-embedding-
+    model_name='text-embedding-3-small'  # This is a fixed value and does not need modification
 )
 
 # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
@@ -70,7 +71,7 @@ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
 embedding_model = OpenAIEmbeddings(
     openai_api_base=endpoint,
     openai_api_key=api_key,
-    model='text-embedding-
+    model='text-embedding-3-small'
 )
 
 
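Note: this commit pins both embedding clients to text-embedding-3-small. Keeping the Chroma-side and LangChain-side embedders on the same model is what makes the persisted vectors queryable at retrieval time. A minimal sanity-check sketch, assuming the embedding_function and embedding_model objects defined above plus valid credentials in endpoint and api_key:

# Sanity-check sketch, not part of the commit: both embedders should
# produce vectors of the same dimensionality (1536 for text-embedding-3-small),
# otherwise queries cannot match the vectors already persisted on disk.
probe = "vitamin B12 deficiency"
chroma_vec = embedding_function([probe])[0]         # chromadb embedding functions are callable
langchain_vec = embedding_model.embed_query(probe)  # LangChain embedder
assert len(chroma_vec) == len(langchain_vec)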
@@ -111,48 +112,51 @@ def expand_query(state):
         Dict: The updated state with the expanded query.
     """
     print("---------Expanding Query---------")
-    system_message =
+    system_message = """
+    You are a query-expansion engine for a medical retrieval system.
+
+    Your job:
+    1. Expand the user's query into 6–8 alternative questions that could retrieve the same medical information.
+
+    Rules:
+    - Do NOT answer the query.
+    - Keep output in the same language as input.
+    - Preserve key entities (e.g., vitamins, disorders, nutrients).
+    - Each query must be ≤ 16 words.
+    - Output strict JSON only. No explanation. No extra text.
+
+    Schema:
+    {{
+      "queries": ["...", "..."]
+    }}
+    """
 
     expand_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Expand this query: {query} using the feedback: {query_feedback}")
     ])
 
     chain = expand_prompt | llm | StrOutputParser()
-    expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]})
+    expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
     print("expanded_query", expanded_query)
     state["expanded_query"] = expanded_query
     return state
 
 
 # Initialize the Chroma vector store for retrieving documents
 vector_store = Chroma(
-    collection_name=
-    persist_directory=
-    embedding_function=embedding_model
+    collection_name='nutritional_hypotheticals',  # Complete the code to define the collection name
+    persist_directory='./nutritional_db',  # Complete the code to define the directory for persistence
+    embedding_function=embedding_model  # Complete the code to define the embedding function
 )
 
 # Create a retriever from the vector store
+# this is the provided code but I want to use the structured retriever
 retriever = vector_store.as_retriever(
-    search_type='similarity',
-    search_kwargs={'k':
+    search_type='similarity',  # Complete the code to define the search type
+    search_kwargs={'k': 6}  # Complete the code to define the number of results to retrieve
 )
 
 def retrieve_context(state):
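The new system prompt demands strict JSON with a "queries" list (the doubled braces {{ }} keep ChatPromptTemplate from reading them as template variables), yet the chain ends in StrOutputParser and stores the raw string in state["expanded_query"]. If the JSON were ever fanned out into multiple retrievals, a small parser would be needed; a sketch with a hypothetical helper name, not in app.py:

import json

# Hypothetical helper: unpack the strict-JSON output the system prompt
# requests, falling back to the raw string on malformed or empty output.
def parse_expanded_queries(raw: str) -> list[str]:
    try:
        data = json.loads(raw)
        queries = data.get("queries", []) if isinstance(data, dict) else []
        return queries or [raw]
    except json.JSONDecodeError:
        return [raw]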
@@ -166,18 +170,21 @@ def retrieve_context(state):
         Dict: The updated state with the retrieved context.
     """
     print("---------retrieve_context---------")
-    query = state
+    query = state.get('expanded_query') or state.get('query')  # Complete the code to define the key for the expanded query
 
     #print("Query used for retrieval:", query)  # Debugging: Print the query
 
-    # Retrieve
+    # Retrieve hypothetical questions from the vector store
     docs = retriever.invoke(query)
     print("Retrieved documents:", docs)  # Debugging: Print the raw docs object
 
     # Extract both page_content and metadata from each document
     context= [
-        {
-            "content": doc.
-            "metadata": doc.metadata
+        {
+            "content": doc.metadata.get("original_content", ""),
+            "metadata": doc.metadata
         }
         for doc in docs
     ]
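The completed retrieve_context reflects a HyDE-style layout: the nutritional_hypotheticals collection appears to store generated questions as page_content, with the source passage tucked into doc.metadata["original_content"], which is why "content" is read from metadata rather than doc.page_content. An illustrative round-trip under that assumption (sample query is made up; assumes ./nutritional_db exists on disk and the function stores its list in state['context'], as the graders below expect):

state = {"query": "Which foods are high in vitamin D?",
         "expanded_query": "What are dietary sources of vitamin D?"}
state = retrieve_context(state)
for item in state["context"]:
    print(item["content"][:80])  # the original passage, not the hypothetical question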
@@ -199,7 +206,8 @@ def craft_response(state: Dict) -> Dict:
         Dict: The updated state with the generated response.
     """
     print("---------craft_response---------")
-    system_message = '''
+    system_message = '''
+    You are a medical assistant specializing in nutritional disorders.
 
     Your job is to generate concise, accurate answers strictly based on the provided context from a textbook or trusted source.
 
@@ -218,20 +226,22 @@
         ("system", system_message),
         ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
     ])
+    context_docs = state.get("context", [])
+    context_string = "\n\n".join(doc["metadata"].get("original_content", "") for doc in context_docs)
 
+    feedback_text = state.get("feedback", "None")
     chain = response_prompt | llm
     response = chain.invoke({
         "query": state['query'],
-        "context":
-        "feedback":
+        "context": context_string,
+        "feedback": feedback_text  # add feedback to the prompt
     })
-    state['response'] = response
-    print("intermediate response: ", response
+    state['response'] = response
+    print("intermediate response: ", response)
 
     return state
 
 
-
 def score_groundedness(state: Dict) -> Dict:
     """
     Checks whether the response is grounded in the retrieved context.
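Here craft_response builds its context from doc["metadata"].get("original_content", "") while score_groundedness below joins doc["content"]; given how retrieve_context fills both from the same metadata key, the grader scores the same text the model saw. One caveat worth flagging: without a StrOutputParser, chain.invoke returns an AIMessage, so state['response'] carries a message object into the scoring prompts, where string formatting would stringify the whole object. A possible unwrap, not in the commit:

# Optional sketch: keep state['response'] as plain text so downstream
# prompt formatting does not embed the AIMessage repr.
text = response.content if hasattr(response, "content") else str(response)
state['response'] = text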
@@ -265,7 +275,7 @@
     chain = groundedness_prompt | llm | StrOutputParser()
     groundedness_score = float(chain.invoke({
         "context": "\n".join([doc["content"] for doc in state['context']]),
-        "response": state[
+        "response": state["response"]  # Complete the code to define the response
     }))
     print("groundedness_score: ", groundedness_score)
     state['groundedness_loop_count'] += 1
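float(chain.invoke(...)) assumes the grading model returns a bare number; any preamble such as "Score: 0.9" raises ValueError. A defensive parse, with a hypothetical helper name not in app.py:

import re

# Hypothetical helper: pull the first numeric token from the grader's
# reply instead of trusting float() on the raw string.
def parse_score(raw: str) -> float:
    match = re.search(r"\d+(?:\.\d+)?", raw)
    return float(match.group()) if match else 0.0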
@@ -275,10 +285,9 @@
     return state
 
 
-
 def check_precision(state: Dict) -> Dict:
     """
-    Checks whether the response precisely addresses the user
+    Checks whether the response precisely addresses the user’s query.
 
     Args:
         state (Dict): The current state of the workflow, containing the query and response.
@@ -308,7 +317,7 @@
     chain = precision_prompt | llm | StrOutputParser()  # Complete the code to define the chain of processing
     precision_score = float(chain.invoke({
         "query": state['query'],
-        "response":
+        "response": state['response']  # Complete the code to access the response from the state
     }))
     state['precision_score'] = precision_score
     print("precision_score:", precision_score)
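The same bare-float() pattern appears here; the parse_score sketch above would slot in unchanged:

# Reusing the hypothetical parse_score helper sketched earlier.
precision_score = parse_score(chain.invoke({
    "query": state['query'],
    "response": state['response'],
}))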
@@ -317,7 +326,6 @@
     return state
 
 
-
 def refine_response(state: Dict) -> Dict:
     """
     Suggests improvements for the generated response.
@@ -357,7 +365,6 @@
     return state
 
 
-
 def refine_query(state: Dict) -> Dict:
     """
     Suggests improvements for the expanded query.
@@ -401,25 +408,24 @@ def should_continue_groundedness(state):
     """Decides if groundedness is sufficient or needs improvement."""
     print("---------should_continue_groundedness---------")
     print("groundedness loop count: ", state['groundedness_loop_count'])
-    if state['groundedness_score'] >=
+    if state['groundedness_score'] >= 0.85:  # Complete the code to define the threshold for groundedness
         print("Moving to precision")
         return "check_precision"
     else:
-        if state["groundedness_loop_count"]
+        if state["groundedness_loop_count"] >= state['loop_max_iter']:
             return "max_iterations_reached"
         else:
             print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
             return "refine_response"
 
-
 def should_continue_precision(state: Dict) -> str:
     """Decides if precision is sufficient or needs improvement."""
     print("---------should_continue_precision---------")
-    print("precision loop count: ", state[
-    if
+    print("precision loop count: ", state["precision_loop_count"])
+    if state["precision_score"] >= 0.85:  # Threshold for precision
         return "pass"  # Complete the workflow
     else:
-        if state[
+        if state["precision_loop_count"] >= state["loop_max_iter"]:  # Maximum allowed loops
             return "max_iterations_reached"
         else:
             print(f"---------Precision Score Threshold Not met. Refining Query-----------")  # Debugging
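The string labels these routers return ("check_precision", "refine_response", "max_iterations_reached", "pass") only resolve inside a graph, and the wiring itself sits outside this diff. The sketch below is therefore an assumption inferred from the function names: a LangGraph StateGraph whose conditional edges map each label to a node, with the refinement loops feeding back into craft_response and expand_query. The AgentState schema and the "refine_query" label are likewise inferred from the keys and prints used throughout app.py.

from typing import List, TypedDict
from langgraph.graph import StateGraph, END

# Assumed state schema, inferred from the keys used throughout app.py.
class AgentState(TypedDict):
    query: str
    expanded_query: str
    query_feedback: str
    feedback: str
    context: List[dict]
    response: str
    groundedness_score: float
    precision_score: float
    groundedness_loop_count: int
    precision_loop_count: int
    loop_max_iter: int

graph = StateGraph(AgentState)
for fn in (expand_query, retrieve_context, craft_response,
           score_groundedness, check_precision, refine_response, refine_query):
    graph.add_node(fn.__name__, fn)

graph.set_entry_point("expand_query")
graph.add_edge("expand_query", "retrieve_context")
graph.add_edge("retrieve_context", "craft_response")
graph.add_edge("craft_response", "score_groundedness")
graph.add_conditional_edges("score_groundedness", should_continue_groundedness, {
    "check_precision": "check_precision",
    "refine_response": "refine_response",
    "max_iterations_reached": END,
})
graph.add_conditional_edges("check_precision", should_continue_precision, {
    "pass": END,
    "refine_query": "refine_query",  # assumed label for the refine path
    "max_iterations_reached": END,
})
graph.add_edge("refine_response", "craft_response")
graph.add_edge("refine_query", "expand_query")
workflow = graph.compile()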