# Summarise-files / ocr.py: optimised OCR
import os
import tempfile
import io
import pdfplumber
from PIL import Image
import logging
from typing import List, Dict, Tuple, Optional
import numpy as np
import easyocr
import re
import requests
from spellchecker import SpellChecker
import dateutil.parser as dparser
from nameparser import HumanName
from functools import lru_cache
import math
from collections import Counter
import time
import google.generativeai as genai
# Initialize spell checker with custom dictionary if needed
spell = SpellChecker()
cache_dir = os.path.join(tempfile.gettempdir(), '.EasyOCR')
os.environ['EASYOCR_MODULE_PATH'] = cache_dir
os.environ['EASYOCR_CACHE_DIR'] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
os.chmod(cache_dir, 0o755)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}
PDF_EXTS = {".pdf"}
# Pre-compile regex patterns for better performance
DATE_PATTERNS = [
re.compile(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b'),
re.compile(r'\b\d{1,2}\s+(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{2,4}\b', re.IGNORECASE),
re.compile(r'\b(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4}\b', re.IGNORECASE)
]
EMAIL_PATTERN = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
PHONE_PATTERN = re.compile(r'(?<!\w)(\+\d{1,2}\s?)?(\(\d{3}\)|\d{3})[-.\s]?\d{3}[-.\s]?\d{4}\b')
NAME_PATTERN = re.compile(r'\b(?:Mr|Mrs|Ms|Dr|Prof)\.?\s*[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b')
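# Illustrative strings these patterns are meant to catch (assumed examples, not exhaustive):
#   DATE_PATTERNS: "12/31/2024", "31 Dec 2024", "Dec 31, 2024"
#   EMAIL_PATTERN: "jane.doe@example.com"
#   PHONE_PATTERN: "555-123-4567", "+1 (555) 123-4567"
#   NAME_PATTERN:  "Dr. Jane Doe" (a title prefix is required, so a bare "Jane Doe" is not matched)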
# Configurable thresholds - making them more conservative
OCR_CONF_THRESHOLD = 0.3
UNSTRUCTURED_ENTROPY_THRESHOLD = 4.0 # Increased to be more conservative
UNSTRUCTURED_MISSPELL_RATIO = 0.3 # Increased to be more conservative
MIN_CONFIDENCE_FOR_CORRECTION = 0.8 # Only correct if very confident
generator = None
LLM_PROMPT_TEMPLATE = """
You are an OCR correction assistant. Only correct text if you are VERY confident it's wrong.
Preserve all names, dates, numbers, emails, and phone numbers exactly as written unless you're 95% sure they're incorrect.
Focus only on obvious misspellings of common words.
Return only the corrected text, with no explanation or commentary.
Text to correct: {text}
"""
# Lazy global reader
ocr_reader = None
def calculate_shannon_entropy(text: str) -> float:
"""
Calculate the Shannon entropy of a text string.
Higher values indicate more randomness/unpredictability.
"""
if not text:
return 0.0
# Count frequency of each character
counter = Counter(text)
text_length = len(text)
# Calculate entropy
entropy_val = 0.0
for count in counter.values():
probability = count / text_length
if probability > 0: # Avoid log(0)
entropy_val -= probability * math.log2(probability)
return entropy_val
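# Reference points for the entropy heuristic (assumed values, for orientation):
#   calculate_shannon_entropy("aaaa") -> 0.0  (one repeated symbol carries no information)
#   calculate_shannon_entropy("abcd") -> 2.0  (four equiprobable symbols -> log2(4) bits)
# Ordinary English prose tends to sit around 4.0-4.5 bits per character at the raw
# character level, consistent with UNSTRUCTURED_ENTROPY_THRESHOLD = 4.0 above.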
def initialize_gemini_generator():
"""Initialize Gemini generator with retries"""
global generator
if generator is not None:
return generator
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
logger.warning("No GOOGLE_API_KEY provided; skipping AI correction")
return None
attempts = 0
max_attempts = 3
while attempts < max_attempts:
try:
logger.info(f"Attempt {attempts + 1} to initialize Gemini generator...")
genai.configure(api_key=api_key)
generator = genai.GenerativeModel(
model_name="gemini-2.0-flash",
safety_settings=[
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]
)
logger.info("Gemini generator initialized successfully")
return generator
except Exception as e:
attempts += 1
logger.error(f"Initialization attempt failed: {e}")
if attempts == max_attempts:
logger.warning(f"All {max_attempts} attempts failed: {e}")
return None
            wait_time = min(2 ** attempts, 10)  # Exponential backoff: 2s after the first failure, 4s after the second
time.sleep(wait_time)
def get_ocr_reader():
"""Initialize EasyOCR only when needed"""
global ocr_reader
if ocr_reader is None:
logger.info("Initializing EasyOCR reader (this may take some time)...")
        try:
            # Run initialization from the cache dir so EasyOCR resolves its model
            # paths there, and restore the working directory even on failure
            original_cwd = os.getcwd()
            os.chdir(cache_dir)
            try:
                ocr_reader = easyocr.Reader(
                    ['en'],
                    gpu=False,
                    verbose=True,
                    model_storage_directory=cache_dir,
                    user_network_directory=cache_dir,
                    download_enabled=True
                )
            finally:
                os.chdir(original_cwd)
            logger.info("EasyOCR initialized successfully ✅")
except Exception as e:
logger.error(f"EasyOCR initialization failed ❌: {e}")
try:
alt_cache_dir = "/tmp/easyocr_cache"
os.makedirs(alt_cache_dir, exist_ok=True)
os.chmod(alt_cache_dir, 0o755)
ocr_reader = easyocr.Reader(
['en'],
gpu=False,
model_storage_directory=alt_cache_dir,
user_network_directory=alt_cache_dir
)
logger.info("EasyOCR alternative initialization successful ✅")
except Exception as alt_error:
logger.error(f"All EasyOCR initialization attempts failed: {alt_error}")
ocr_reader = None
return ocr_reader
@lru_cache(maxsize=128)
def correct_spelling(text: str) -> str:
"""Correct spelling errors in text while preserving proper nouns and special terms"""
# Don't correct very short texts or texts that look like they contain important data
if len(text) < 5 or any(char.isdigit() for char in text):
return text
words = text.split()
corrected_words = []
for word in words:
# Skip words that are likely proper nouns, numbers, or special patterns
if (word.istitle() or word.isupper() or
any(char.isdigit() for char in word) or
EMAIL_PATTERN.match(word) or
PHONE_PATTERN.match(word) or
'.' in word or
'-' in word or
len(word) <= 2):
corrected_words.append(word)
continue
# Only correct if the word is definitely wrong
if spell.unknown([word]):
correction = spell.correction(word)
# Only apply correction if it's significantly different and we're confident
if (correction and correction != word and
len(correction) > 2 and # Avoid very short corrections
not any(pat.match(word) for pat in DATE_PATTERNS)): # Don't "correct" dates
                logger.info(f"Corrected '{word}' to '{correction}'")
corrected_words.append(correction)
else:
corrected_words.append(word)
else:
corrected_words.append(word)
return ' '.join(corrected_words)
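# Illustrative behaviour (assumed; actual corrections depend on pyspellchecker's dictionary):
#   correct_spelling("teh meeting was cancelled") -> "the meeting was cancelled"
#   correct_spelling("Ref 2024-11-05 invoice 123") -> unchanged (digits present, so the text is skipped)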
def validate_and_format_date(date_str: str) -> str:
"""Try to parse and standardize date formats - but be conservative"""
# If it already looks like a standard date format, don't change it
if re.match(r'\d{4}-\d{2}-\d{2}', date_str):
return date_str
try:
parsed_date = dparser.parse(date_str, fuzzy=True)
formatted = parsed_date.strftime("%Y-%m-%d")
# Only return formatted date if it's significantly different and we're confident
if abs(len(date_str) - len(formatted)) <= 2: # Don't change formats drastically
return formatted
else:
return date_str
    except (ValueError, OverflowError, TypeError):
        return date_str
def extract_and_normalize_names(text: str) -> Dict[str, str]:
"""Extract and normalize potential names - but be conservative"""
potential_names = NAME_PATTERN.findall(text)
normalized = {}
for name_str in potential_names:
        # Defensive skip for plain "First Last" strings; NAME_PATTERN matches always
        # carry a title prefix, so in practice this guard should not fire
        if re.match(r'^[A-Z][a-z]+ [A-Z][a-z]+$', name_str):
# Already in "First Last" format, don't change
continue
parsed = HumanName(name_str)
if parsed.first and parsed.last: # Only normalize if we can extract both parts
normalized_name = f"{parsed.first} {parsed.last}".strip()
# Only add if it's different from original and makes sense
if (normalized_name != name_str and
len(normalized_name.split()) >= 2 and # At least two parts
all(len(part) > 1 for part in normalized_name.split())): # Each part has more than 1 char
normalized[name_str] = normalized_name
logger.info(f"Normalized name '{name_str}' to '{normalized[name_str]}'")
return normalized
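# Illustrative behaviour (assumed; exact splits depend on nameparser's heuristics):
#   "Dr. Jane A. Doe" -> normalized to "Jane Doe" (title and middle name dropped)
#   "Mr Smith" is never seen here: NAME_PATTERN needs two capitalized words after the title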
@lru_cache(maxsize=128)
def validate_place_name(place_name: str) -> Tuple[str, float]:
"""Validate and correct place name using Nominatim API - but be conservative"""
# Don't attempt to validate very short place names
if len(place_name) < 5:
return place_name, 0.3
try:
url = f"https://nominatim.openstreetmap.org/search?q={requests.utils.quote(place_name)}&format=json&limit=1"
headers = {'User-Agent': 'OCRTextExtractor/1.0'}
response = requests.get(url, headers=headers, timeout=5)
if response.status_code == 200:
data = response.json()
if data:
corrected = data[0]['display_name']
# Only use correction if it's very similar to the original
if (place_name.lower() in corrected.lower() or
corrected.lower() in place_name.lower()):
confidence = 0.9
logger.info(f"Validated '{place_name}' as '{corrected}' with confidence {confidence}")
return corrected, confidence
return place_name, 0.3 # Lower confidence for fallback
except Exception as e:
logger.error(f"Place validation failed: {e}")
return place_name, 0.3
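# Illustrative behaviour (assumed; live results depend on the Nominatim service):
#   validate_place_name("Baker Street") -> ("Baker Street, Marylebone, London, ...", 0.9)
#     when the returned display_name contains the query as a substring
#   validate_place_name("Par") -> ("Par", 0.3)  # under 5 characters, never sent to the API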
def call_llm_for_correction(text: str) -> str:
"""Call Google Gemini for advanced correction/verification - but be conservative"""
# Don't call LLM for very short texts or texts that look like they contain structured data
if len(text) < 20 or any(char.isdigit() for char in text):
return text
gen = initialize_gemini_generator()
if gen is None:
logger.warning("Gemini generator not available; skipping AI correction")
return text
try:
prompt = LLM_PROMPT_TEMPLATE.format(text=text)
logger.info("Calling Gemini for text correction...")
response = gen.generate_content(prompt)
corrected = response.text.strip()
# Only use the correction if it's not too different from the original
original_words = set(text.lower().split())
corrected_words = set(corrected.lower().split())
common_words = original_words.intersection(corrected_words)
similarity = len(common_words) / max(len(original_words), len(corrected_words))
if similarity > 0.7: # Only use if at least 70% similar
logger.info(f"Gemini corrected text to: {corrected}")
return corrected
else:
logger.info("Gemini correction was too different from original, keeping original")
return text
except Exception as e:
logger.error(f"Gemini call failed: {e}")
return text # Fallback to original text
def is_unstructured(text: str, ocr_results: List) -> bool:
"""Detect if text is unstructured based on heuristics - more conservative now"""
if not text:
return False
# If text contains structured data patterns, it's probably structured
if (any(pat.search(text) for pat in DATE_PATTERNS) or
EMAIL_PATTERN.search(text) or
PHONE_PATTERN.search(text) or
NAME_PATTERN.search(text)):
return False
# Average OCR confidence
confidences = [res[2] for res in ocr_results if len(res) == 3]
avg_conf = np.mean(confidences) if confidences else 1.0
    # Misspelling ratio
    words = text.split()
    if len(words) < 5:  # Don't check spelling for very short texts
        misspelled_ratio = 0.0
    else:
        misspelled_ratio = len(spell.unknown(words)) / len(words)
    # Entropy (randomness)
    text_entropy = calculate_shannon_entropy(text)
    # More conservative thresholds
    return ((avg_conf < OCR_CONF_THRESHOLD - 0.1) or                   # lower confidence threshold
            (misspelled_ratio > UNSTRUCTURED_MISSPELL_RATIO + 0.1) or  # higher misspelling threshold
            (text_entropy > UNSTRUCTURED_ENTROPY_THRESHOLD + 0.5))     # higher entropy threshold
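# Worked example (assumed numbers): with avg_conf = 0.15, misspelled_ratio = 0.5 and
# entropy = 4.8, the first clause alone (0.15 < 0.3 - 0.1 = 0.2) flags the text.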
def post_process_text(text: str, ocr_results: Optional[List] = None) -> str:
"""Apply various post-processing techniques to clean and validate text - more conservative"""
if not text or len(text.strip()) < 3:
return text
# Store original text for comparison
original_text = text
# Basic corrections always - but more conservative now
corrected_text = correct_spelling(text)
# Entity-specific processing
lines = corrected_text.split('\n')
processed_lines = []
for line in lines:
processed_line = line
# Dates - only format if they look like dates and formatting makes sense
for pattern in DATE_PATTERNS:
dates = pattern.findall(line)
for date in dates:
formatted = validate_and_format_date(date)
# Only replace if the formatted date is different but still recognizable
if formatted != date and len(formatted) >= 6: # At least YYYY-MM
processed_line = processed_line.replace(date, formatted)
# Names - only normalize if they don't already look like standard names
names_dict = extract_and_normalize_names(processed_line)
for original, normalized in names_dict.items():
processed_line = processed_line.replace(original, normalized)
# Preserve emails/phones (already skipped in spelling)
emails = EMAIL_PATTERN.findall(line)
phones = PHONE_PATTERN.findall(line)
if emails:
logger.info(f"Found emails: {emails}")
if phones:
logger.info(f"Found phone numbers: {phones}")
processed_lines.append(processed_line)
corrected_text = '\n'.join(processed_lines)
# Advanced AI/external only if unstructured AND significantly different from original
if (ocr_results and
is_unstructured(corrected_text, ocr_results) and
calculate_shannon_entropy(corrected_text) > calculate_shannon_entropy(original_text) + 0.5):
# Validate places (extract potential places heuristically)
        potential_places = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?\s*(?:Street|St|Ave|Road|Rd|City|Town)\b', corrected_text)
for place in set(potential_places):
corrected_place, conf = validate_place_name(place)
if conf > MIN_CONFIDENCE_FOR_CORRECTION:
corrected_text = corrected_text.replace(place, corrected_place)
# LLM for overall verification/correction - only if really needed
corrected_text = call_llm_for_correction(corrected_text)
return corrected_text
def extract_text_from_image(file_bytes: bytes) -> str:
reader = get_ocr_reader()
if reader is None:
return ""
try:
image = Image.open(io.BytesIO(file_bytes)).convert('RGB')
image_np = np.array(image)
results = reader.readtext(image_np, paragraph=True, detail=1)
logger.info(f"OCR found {len(results)} text segments")
text_parts = []
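        # EasyOCR result shapes: detail=1 normally yields (bbox, text, confidence),
        # but with paragraph=True some versions return (bbox, text) only, so both
        # shapes are handled below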
for result in results:
if len(result) == 3:
bbox, text, confidence = result
elif len(result) == 2:
bbox, text = result
confidence = 1.0
else:
continue
if confidence > OCR_CONF_THRESHOLD:
text_parts.append(text)
raw_text = "\n".join(text_parts).strip()
return post_process_text(raw_text, results)
except Exception as e:
logger.error(f"Image OCR failed: {e}")
return ""
def extract_text_from_pdf(file_bytes: bytes) -> List[str]:
text_parts = []
try:
with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
logger.info(f"Opened PDF with {len(pdf.pages)} pages")
for i, page in enumerate(pdf.pages):
try:
page_text = page.extract_text() or ""
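                    # Heuristic: if pdfplumber finds almost no text, the page is
                    # likely a scanned image, so fall back to OCR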
if len(page_text.strip()) < 50:
reader = get_ocr_reader()
if reader:
im = page.to_image(resolution=200).original.convert("RGB")
image_np = np.array(im)
results = reader.readtext(image_np, paragraph=True, detail=1)
ocr_text_parts = []
for result in results:
if len(result) == 3:
bbox, text, confidence = result
elif len(result) == 2:
bbox, text = result
confidence = 1.0
else:
continue
if confidence > OCR_CONF_THRESHOLD:
ocr_text_parts.append(text)
raw_ocr_text = "\n".join(ocr_text_parts).strip()
processed = post_process_text(raw_ocr_text, results)
text_parts.append(processed if processed else page_text)
else:
text_parts.append(page_text)
else:
processed = post_process_text(page_text)
text_parts.append(processed)
except Exception as e:
logger.error(f"Page {i+1} extraction failed: {e}")
text_parts.append("")
return text_parts
except Exception as e:
logger.error(f"PDF extraction failed: {e}")
return []
def guess_and_extract(filename: str, file_bytes: bytes) -> List[str]:
ext = ("." + filename.lower().rsplit(".", 1)[-1]) if "." in filename else ""
try:
if ext in PDF_EXTS:
return extract_text_from_pdf(file_bytes)
elif ext in IMAGE_EXTS:
text = extract_text_from_image(file_bytes)
return [text] if text else []
        else:
            # Treat anything else as plain text; latin-1 decodes any byte
            # sequence, so it acts as the terminal fallback ("iso-8859-1" is
            # simply an alias for it in Python)
            for encoding in ("utf-8", "latin-1"):
                try:
                    raw_text = file_bytes.decode(encoding).strip()
                    processed = post_process_text(raw_text)
                    return [processed]
                except UnicodeDecodeError:
                    continue
            return []
except Exception as e:
logger.error(f"Extraction failed for {filename}: {e}")
return []
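# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). "sample.pdf" is a
# placeholder path, assumed purely for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_path = "sample.pdf"  # hypothetical input file
    if os.path.exists(sample_path):
        with open(sample_path, "rb") as f:
            pages = guess_and_extract(os.path.basename(sample_path), f.read())
        for page_num, page_text in enumerate(pages, start=1):
            print(f"--- page {page_num} ---")
            print(page_text[:500])  # preview the first 500 characters of each page
    else:
        logger.info("No sample file found at %s; nothing to demonstrate.", sample_path)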