"""RAG question-answering pipeline: hybrid FAISS + BM25 retrieval, cross-encoder
reranking, finance-only query guardrails, and LLM answer generation."""

import os
import re
import json
import pickle
import logging
from typing import List, Dict

import faiss
import nltk
import numpy as np
import requests
from dotenv import load_dotenv
from nltk.corpus import stopwords
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder, util

# Load environment variables (e.g. HF_API_KEY) from a local .env file.
load_dotenv()

# Console logging with timestamps.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# English stopwords are stripped from queries before BM25 scoring.
nltk.download("stopwords")
STOPWORDS = set(stopwords.words("english"))

# Silence fork-related warnings from the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
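
# Prebuilt retrieval artifacts (FAISS index, BM25 object, chunk metadata) and the
# embedding / reranking model identifiers used throughout the pipeline.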
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L-6-v2"
OUT_DIR = "data/index_merged"

FAISS_PATH = os.path.join(OUT_DIR, "faiss_merged.index")
BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")

logger.info("Loading FAISS, BM25, metadata, and models...")
try:
    # FAISS index over the merged chunk embeddings.
    faiss_index = faiss.read_index(FAISS_PATH)
    # Pickled BM25 object, stored under the "bm25" key.
    with open(BM25_PATH, "rb") as f:
        bm25_obj = pickle.load(f)
    bm25 = bm25_obj["bm25"]
    # Chunk metadata: one dict per chunk with at least "content" and "chunk_size".
    with open(META_PATH, "rb") as f:
        meta: List[Dict] = pickle.load(f)
    embed_model = SentenceTransformer(EMBED_MODEL)
    reranker = CrossEncoder(CROSS_ENCODER)
    api_key = os.getenv("HF_API_KEY")
    if not api_key:
        logger.error("HF_API_KEY environment variable not set. Please check your .env file or environment.")
        raise ValueError("HF_API_KEY environment variable not set.")
    # OpenAI-compatible client pointed at the Hugging Face inference router.
    client = OpenAI(
        base_url="https://router.huggingface.co/v1",
        api_key=api_key
    )
except Exception as e:
    logger.error(f"Error loading models or indexes: {e}")
    raise


def get_mistral_answer(query: str, context: str) -> str:
    """Ask the hosted chat model to answer the query using the retrieved context.

    Uses the OpenAI-compatible Hugging Face router client configured above.
    Errors are logged and returned as a string rather than raised.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer in full sentences using context."
    try:
        logger.info(f"Calling Mistral API for query: {query}")
        completion = client.chat.completions.create(
            model="dphn/Dolphin-Mistral-24B-Venice-Edition:featherless-ai",
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        )
        answer = str(completion.choices[0].message.content)
        logger.info(f"Mistral API response: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Error in Mistral API call: {e}")
        return f"Error fetching answer from LLM: {e}"
BLOCKED_TERMS = ["weather", "cricket", "movie", "song", "football", "holiday",
                 "travel", "recipe", "music", "game", "sports", "politics", "election"]

FINANCE_DOMAINS = [
    "financial reporting", "balance sheet", "income statement",
    "assets and liabilities", "equity", "revenue", "profit and loss",
    "goodwill impairment", "cash flow", "dividends", "taxation",
    "investment", "valuation", "capital structure", "ownership interests",
    "subsidiaries", "shareholders equity", "expenses", "earnings",
    "debt", "amortization", "depreciation"
]
# Reference embeddings for the finance topics, used by the semantic guardrail check.
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)


def validate_query(query: str, threshold: float = 0.5) -> bool:
    """Return True if the query passes the blocklist and is semantically close to a finance topic."""
    q_lower = query.lower()
    if any(bad in q_lower for bad in BLOCKED_TERMS):
        logger.info("[Guardrail] Rejected by blocklist.")
        return False
    q_emb = embed_model.encode(query, convert_to_tensor=True)
    sim_scores = util.cos_sim(q_emb, finance_embeds)
    max_score = float(sim_scores.max())
    if max_score > threshold:
        logger.info(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
        return True
    logger.info(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
    return False
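

# --- Retrieval: query preprocessing, hybrid FAISS + BM25 candidates, cross-encoder reranking ---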
def preprocess_query(query: str, remove_stopwords: bool = True) -> str:
    """Lowercase the query, strip non-alphanumeric characters, and optionally drop stopwords."""
    query = query.lower()
    query = re.sub(r"[^a-z0-9\s]", " ", query)
    tokens = query.split()
    if remove_stopwords:
        tokens = [t for t in tokens if t not in STOPWORDS]
    return " ".join(tokens)


def hybrid_candidates(query: str, candidate_k: int = 50, alpha: float = 0.5) -> List[int]:
    """Blend FAISS (dense) and BM25 (lexical) scores and return the top candidate_k chunk ids."""
    # Dense retrieval: embed the query (stopwords kept) and search the FAISS index.
    q_emb = embed_model.encode([preprocess_query(query, remove_stopwords=False)],
                               convert_to_numpy=True, normalize_embeddings=True)
    faiss_scores, faiss_ids = faiss_index.search(q_emb, max(candidate_k, 50))
    faiss_ids = faiss_ids[0]
    faiss_scores = faiss_scores[0]

    # Lexical retrieval: BM25 scores over the stopword-free tokenized query.
    tokenized_query = preprocess_query(query).split()
    bm25_scores = bm25.get_scores(tokenized_query)

    # Take the union of the top-N ids from both retrievers.
    topN = max(candidate_k, 50)
    bm25_top = np.argsort(bm25_scores)[::-1][:topN]
    faiss_top = faiss_ids[:topN]
    union_ids = np.unique(np.concatenate([bm25_top, faiss_top]))

    # Candidates returned only by BM25 (absent from the FAISS results) are backfilled
    # with the lowest observed FAISS score before normalization.
    faiss_score_map = {int(i): float(s) for i, s in zip(faiss_ids, faiss_scores)}
    f_list = [faiss_score_map.get(int(i)) for i in union_ids]
    present = [s for s in f_list if s is not None]
    f_min = min(present) if present else 0.0
    f_arr = np.array([f_min if s is None else s for s in f_list], dtype=float)
    b_arr = np.array([bm25_scores[int(i)] for i in union_ids], dtype=float)

    # Min-max normalize each score list, then combine with weight alpha on the dense side.
    def _norm(x):
        return (x - np.min(x)) / (np.ptp(x) + 1e-9)

    combined = alpha * _norm(f_arr) + (1 - alpha) * _norm(b_arr)
    order = np.argsort(combined)[::-1]
    return union_ids[order][:candidate_k].tolist()


def rerank_cross_encoder(query: str, cand_ids: List[int], top_k: int = 10) -> List[Dict]:
    """Rescore candidate chunks with the cross-encoder and return the top_k as metadata dicts."""
    pairs = [(query, meta[i]["content"]) for i in cand_ids]
    scores = reranker.predict(pairs)
    order = np.argsort(scores)[::-1][:top_k]
    return [
        {
            "id": cand_ids[i],
            "chunk_size": meta[cand_ids[i]]["chunk_size"],
            "content": meta[cand_ids[i]]["content"],
            "rerank_score": float(scores[i]),
        }
        for i in order
    ]
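

# --- Answer generation: deterministic year/concept table lookup with LLM fallback ---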
def extract_value_for_year_and_concept(year: str, concept: str, context_docs: List[Dict]) -> str:
    """Try to read the value for (year, concept) directly from table-like text in the retrieved chunks.

    Assumes a header row listing years (e.g. "2022  2023") followed by rows whose
    numeric columns line up positionally with those years. Returns "" if nothing matches.
    """
    target_year = str(year)
    concept_lower = concept.lower()
    for doc in context_docs:
        text = doc.get("content", "")
        # Keep only non-empty lines containing at least one digit (candidate table rows).
        lines = [line for line in text.split("\n") if line.strip() and any(c.isdigit() for c in line)]
        header_idx = None
        year_to_col = {}
        # Find the first line containing years and map each year to its column position.
        for idx, line in enumerate(lines):
            years_in_line = re.findall(r"20\d{2}", line)
            if years_in_line:
                for col_idx, y in enumerate(years_in_line):
                    year_to_col[y] = col_idx
                header_idx = idx
                break
        if target_year not in year_to_col or header_idx is None:
            continue
        # Scan the rows below the header for the concept and pick the column for the target year.
        for line in lines[header_idx + 1:]:
            if concept_lower in line.lower():
                cols = re.split(r"\s{2,}|\t", line)
                col_idx = year_to_col[target_year]
                if col_idx < len(cols):
                    return cols[col_idx].replace(",", "")
    return ""


def generate_answer(query: str, top_k: int = 5, candidate_k: int = 50, alpha: float = 0.6):
    """End-to-end pipeline: guardrail check, hybrid retrieval, reranking, then answer generation."""
    logger.info(f"Received query: {query}")
    try:
        if not validate_query(query):
            logger.warning("Query rejected: Not finance-related.")
            return "Query rejected: Please ask finance-related questions."

        # Retrieve and rerank supporting chunks.
        cand_ids = hybrid_candidates(query, candidate_k=candidate_k, alpha=alpha)
        logger.info(f"Hybrid candidates retrieved: {cand_ids}")
        reranked = rerank_cross_encoder(query, cand_ids, top_k=top_k)
        logger.info(f"Reranked top docs: {[d['id'] for d in reranked]}")

        # If the query names a year, try a deterministic lookup in the retrieved tables first.
        year_match = re.search(r"(20\d{2})", query)
        year = year_match.group(0) if year_match else None
        concept = re.sub(r"for the year 20\d{2}", "", query, flags=re.IGNORECASE).strip()

        year_specific_answer = None
        if year and concept:
            year_specific_answer = extract_value_for_year_and_concept(year, concept, reranked)
            logger.info(f"Year-specific answer: {year_specific_answer}")

        if year_specific_answer:
            answer = year_specific_answer
        else:
            # Otherwise fall back to the LLM with the reranked chunks as context.
            context_text = "\n".join([d["content"] for d in reranked])
            answer = get_mistral_answer(query, context_text)

        final_answer = answer
        logger.info(f"Final Answer: {final_answer}")
        return final_answer
    except Exception as e:
        logger.error(f"Error in RAG pipeline: {e}")
        return f"Error in RAG pipeline: {e}"