import os
import tempfile
import io
import pdfplumber
from PIL import Image
import logging
from typing import List, Dict, Optional, Tuple
import numpy as np
import easyocr
import re
import requests
from spellchecker import SpellChecker
import dateutil.parser as dparser
from nameparser import HumanName
import math
from collections import Counter
import time
import google.generativeai as genai

# Initialize spell checker with custom dictionary if needed
spell = SpellChecker()

cache_dir = os.path.join(tempfile.gettempdir(), '.EasyOCR')
os.environ['EASYOCR_MODULE_PATH'] = cache_dir
os.environ['EASYOCR_CACHE_DIR'] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
os.chmod(cache_dir, 0o755)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}
PDF_EXTS = {".pdf"}
# Pre-compile regex patterns for better performance.
# Groups are non-capturing so findall() returns the full matched string.
DATE_PATTERNS = [
    re.compile(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b'),
    re.compile(r'\b\d{1,2}\s+(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{2,4}\b', re.IGNORECASE),
    re.compile(r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{2,4}\b', re.IGNORECASE)
]
EMAIL_PATTERN = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
PHONE_PATTERN = re.compile(r'\b(?:\+\d{1,2}\s?)?(?:\(\d{3}\)|\d{3})[-.\s]?\d{3}[-.\s]?\d{4}\b')
NAME_PATTERN = re.compile(r'\b(?:Mr|Mrs|Ms|Dr|Prof)\.?\s*[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b')
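# Illustrative matches (a sketch for orientation, not exercised by the module):
#   EMAIL_PATTERN.search("Contact: jane.doe@example.com") matches "jane.doe@example.com"
#   PHONE_PATTERN.search("Call 555-123-4567")             matches "555-123-4567"
#   DATE_PATTERNS[0].search("Due 12/31/2024")             matches "12/31/2024"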
# Configurable thresholds - making them more conservative
OCR_CONF_THRESHOLD = 0.3
UNSTRUCTURED_ENTROPY_THRESHOLD = 4.0  # Increased to be more conservative
UNSTRUCTURED_MISSPELL_RATIO = 0.3  # Increased to be more conservative
MIN_CONFIDENCE_FOR_CORRECTION = 0.8  # Only correct if very confident

generator = None

LLM_PROMPT_TEMPLATE = """
You are an OCR correction assistant. Only correct text if you are VERY confident it's wrong.
Preserve all names, dates, numbers, emails, and phone numbers exactly as written unless you're 95% sure they're incorrect.
Focus only on obvious misspellings of common words.
Text to correct: {text}
"""

# Lazy global reader
ocr_reader = None
def calculate_shannon_entropy(text: str) -> float:
    """
    Calculate the Shannon entropy of a text string.
    Higher values indicate more randomness/unpredictability.
    """
    if not text:
        return 0.0
    # Count frequency of each character
    counter = Counter(text)
    text_length = len(text)
    # Calculate entropy
    entropy_val = 0.0
    for count in counter.values():
        probability = count / text_length
        if probability > 0:  # Avoid log(0)
            entropy_val -= probability * math.log2(probability)
    return entropy_val
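# Quick sanity checks (illustrative):
#   calculate_shannon_entropy("aaaa") -> 0.0  (one symbol, no uncertainty)
#   calculate_shannon_entropy("abab") -> 1.0  (two equiprobable symbols)
#   calculate_shannon_entropy("abcd") -> 2.0  (four equiprobable symbols)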
def initialize_gemini_generator():
    """Initialize Gemini generator with retries"""
    global generator
    if generator is not None:
        return generator
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        logger.warning("No GOOGLE_API_KEY provided; skipping AI correction")
        return None
    attempts = 0
    max_attempts = 3
    while attempts < max_attempts:
        try:
            logger.info(f"Attempt {attempts + 1} to initialize Gemini generator...")
            genai.configure(api_key=api_key)
            generator = genai.GenerativeModel(
                model_name="gemini-2.0-flash",
                safety_settings=[
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                ]
            )
            logger.info("Gemini generator initialized successfully")
            return generator
        except Exception as e:
            attempts += 1
            logger.error(f"Initialization attempt failed: {e}")
            if attempts == max_attempts:
                logger.warning(f"All {max_attempts} attempts failed: {e}")
                return None
            wait_time = min(2 ** attempts, 10)  # Exponential backoff
            time.sleep(wait_time)
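# Usage note (sketch): the key is read from the environment, e.g.
#   GOOGLE_API_KEY=<your-key> python app.py
# Without it, initialize_gemini_generator() returns None and AI correction is skipped.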
def get_ocr_reader():
    """Initialize EasyOCR only when needed"""
    global ocr_reader
    if ocr_reader is None:
        logger.info("Initializing EasyOCR reader (this may take some time)...")
        original_cwd = os.getcwd()
        try:
            os.chdir(cache_dir)
            ocr_reader = easyocr.Reader(
                ['en'],
                gpu=False,
                verbose=True,
                model_storage_directory=cache_dir,
                user_network_directory=cache_dir,
                download_enabled=True
            )
            logger.info("EasyOCR initialized successfully ✅")
        except Exception as e:
            logger.error(f"EasyOCR initialization failed ❌: {e}")
            try:
                alt_cache_dir = "/tmp/easyocr_cache"
                os.makedirs(alt_cache_dir, exist_ok=True)
                os.chmod(alt_cache_dir, 0o755)
                ocr_reader = easyocr.Reader(
                    ['en'],
                    gpu=False,
                    model_storage_directory=alt_cache_dir,
                    user_network_directory=alt_cache_dir
                )
                logger.info("EasyOCR alternative initialization successful ✅")
            except Exception as alt_error:
                logger.error(f"All EasyOCR initialization attempts failed: {alt_error}")
                ocr_reader = None
        finally:
            # Always restore the working directory, even if initialization failed
            os.chdir(original_cwd)
    return ocr_reader
def correct_spelling(text: str) -> str:
    """Correct spelling errors in text while preserving proper nouns and special terms"""
    # Don't correct very short texts or texts that look like they contain important data
    if len(text) < 5 or any(char.isdigit() for char in text):
        return text
    words = text.split()
    corrected_words = []
    for word in words:
        # Skip words that are likely proper nouns, numbers, or special patterns
        if (word.istitle() or word.isupper() or
                any(char.isdigit() for char in word) or
                EMAIL_PATTERN.match(word) or
                PHONE_PATTERN.match(word) or
                '.' in word or
                '-' in word or
                len(word) <= 2):
            corrected_words.append(word)
            continue
        # Only correct if the word is definitely wrong
        if spell.unknown([word]):
            correction = spell.correction(word)
            # Only apply correction if it's significantly different and we're confident
            if (correction and correction != word and
                    len(correction) > 2 and  # Avoid very short corrections
                    not any(pat.match(word) for pat in DATE_PATTERNS)):  # Don't "correct" dates
                logger.info(f"Corrected '{word}' to '{correction}'")
                corrected_words.append(correction)
            else:
                corrected_words.append(word)
        else:
            corrected_words.append(word)
    return ' '.join(corrected_words)
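# Example behaviour (hedged; corrections depend on pyspellchecker's dictionary):
#   correct_spelling("The weaher is nice")  -> likely "The weather is nice"
#   correct_spelling("Invoice 12345 due")   -> unchanged (digit guard short-circuits)
#   correct_spelling("Dr. Smith")           -> unchanged (title-case words are skipped)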
def validate_and_format_date(date_str: str) -> str:
    """Try to parse and standardize date formats - but be conservative"""
    # If it already looks like a standard date format, don't change it
    if re.match(r'\d{4}-\d{2}-\d{2}', date_str):
        return date_str
    try:
        parsed_date = dparser.parse(date_str, fuzzy=True)
        formatted = parsed_date.strftime("%Y-%m-%d")
        # Only return formatted date if it's significantly different and we're confident
        if abs(len(date_str) - len(formatted)) <= 2:  # Don't change formats drastically
            return formatted
        else:
            return date_str
    except (ValueError, OverflowError):
        return date_str
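# Example behaviour (illustrative; dateutil resolves ambiguous forms itself):
#   validate_and_format_date("12/25/2023") -> "2023-12-25"
#   validate_and_format_date("2023-01-15") -> "2023-01-15" (already standard, untouched)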
def extract_and_normalize_names(text: str) -> Dict[str, str]:
    """Extract and normalize potential names - but be conservative"""
    potential_names = NAME_PATTERN.findall(text)
    normalized = {}
    for name_str in potential_names:
        # Don't "normalize" if it already looks like a standard name format
        if re.match(r'^[A-Z][a-z]+ [A-Z][a-z]+$', name_str):
            # Already in "First Last" format, don't change
            continue
        parsed = HumanName(name_str)
        if parsed.first and parsed.last:  # Only normalize if we can extract both parts
            normalized_name = f"{parsed.first} {parsed.last}".strip()
            # Only add if it's different from original and makes sense
            if (normalized_name != name_str and
                    len(normalized_name.split()) >= 2 and  # At least two parts
                    all(len(part) > 1 for part in normalized_name.split())):  # Each part has more than 1 char
                normalized[name_str] = normalized_name
                logger.info(f"Normalized name '{name_str}' to '{normalized[name_str]}'")
    return normalized
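# Example behaviour (a sketch; relies on nameparser's heuristics):
#   extract_and_normalize_names("Signed by Dr. Jane Smith")
#     -> {"Dr. Jane Smith": "Jane Smith"}  (title stripped by HumanName)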
def validate_place_name(place_name: str) -> Tuple[str, float]:
    """Validate and correct place name using Nominatim API - but be conservative"""
    # Don't attempt to validate very short place names
    if len(place_name) < 5:
        return place_name, 0.3
    try:
        url = f"https://nominatim.openstreetmap.org/search?q={requests.utils.quote(place_name)}&format=json&limit=1"
        headers = {'User-Agent': 'OCRTextExtractor/1.0'}
        response = requests.get(url, headers=headers, timeout=5)
        if response.status_code == 200:
            data = response.json()
            if data:
                corrected = data[0]['display_name']
                # Only use correction if it's very similar to the original
                if (place_name.lower() in corrected.lower() or
                        corrected.lower() in place_name.lower()):
                    confidence = 0.9
                    logger.info(f"Validated '{place_name}' as '{corrected}' with confidence {confidence}")
                    return corrected, confidence
        return place_name, 0.3  # Lower confidence for fallback
    except Exception as e:
        logger.error(f"Place validation failed: {e}")
        return place_name, 0.3
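# Example behaviour (heavily hedged; depends on network access and live Nominatim data):
#   validate_place_name("Baker Street") may return ("Baker Street, ...", 0.9) on a
#   containing match; offline or unmatched, it falls back to (place_name, 0.3).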
def call_llm_for_correction(text: str) -> str:
    """Call Google Gemini for advanced correction/verification - but be conservative"""
    # Don't call LLM for very short texts or texts that look like they contain structured data
    if len(text) < 20 or any(char.isdigit() for char in text):
        return text
    gen = initialize_gemini_generator()
    if gen is None:
        logger.warning("Gemini generator not available; skipping AI correction")
        return text
    try:
        prompt = LLM_PROMPT_TEMPLATE.format(text=text)
        logger.info("Calling Gemini for text correction...")
        response = gen.generate_content(prompt)
        corrected = response.text.strip()
        # Only use the correction if it's not too different from the original
        original_words = set(text.lower().split())
        corrected_words = set(corrected.lower().split())
        common_words = original_words.intersection(corrected_words)
        similarity = len(common_words) / max(len(original_words), len(corrected_words))
        if similarity > 0.7:  # Only use if at least 70% similar
            logger.info(f"Gemini corrected text to: {corrected}")
            return corrected
        else:
            logger.info("Gemini correction was too different from original, keeping original")
            return text
    except Exception as e:
        logger.error(f"Gemini call failed: {e}")
        return text  # Fallback to original text
def is_unstructured(text: str, ocr_results: List) -> bool:
    """Detect if text is unstructured based on heuristics - more conservative now"""
    if not text:
        return False
    # If text contains structured data patterns, it's probably structured
    if (any(pat.search(text) for pat in DATE_PATTERNS) or
            EMAIL_PATTERN.search(text) or
            PHONE_PATTERN.search(text) or
            NAME_PATTERN.search(text)):
        return False
    # Average OCR confidence
    confidences = [res[2] for res in ocr_results if len(res) == 3]
    avg_conf = np.mean(confidences) if confidences else 1.0
    # Misspelling ratio
    words = text.split()
    if len(words) < 5:  # Don't check spelling for very short texts
        misspelled = 0
    else:
        misspelled = len(spell.unknown(words)) / len(words) if words else 0
    # Entropy (randomness)
    text_entropy = calculate_shannon_entropy(text)
    # More conservative thresholds
    return ((avg_conf < OCR_CONF_THRESHOLD - 0.1) or  # Lower confidence threshold
            (misspelled > UNSTRUCTURED_MISSPELL_RATIO + 0.1) or  # Higher misspelling threshold
            (text_entropy > UNSTRUCTURED_ENTROPY_THRESHOLD + 0.5))  # Higher entropy threshold
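# Example behaviour (illustrative):
#   is_unstructured("Contact jane.doe@example.com", []) -> False
#   (the email pattern marks the text as structured before any scoring runs);
#   a low-confidence, heavily misspelled blob with no recognizable entities
#   would instead trip one of the three thresholds and return True.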
def post_process_text(text: str, ocr_results: Optional[List] = None) -> str:
    """Apply various post-processing techniques to clean and validate text - more conservative"""
    if not text or len(text.strip()) < 3:
        return text
    # Store original text for comparison
    original_text = text
    # Basic corrections always - but more conservative now
    corrected_text = correct_spelling(text)
    # Entity-specific processing
    lines = corrected_text.split('\n')
    processed_lines = []
    for line in lines:
        processed_line = line
        # Dates - only format if they look like dates and formatting makes sense
        for pattern in DATE_PATTERNS:
            dates = pattern.findall(line)
            for date in dates:
                formatted = validate_and_format_date(date)
                # Only replace if the formatted date is different but still recognizable
                if formatted != date and len(formatted) >= 6:  # At least YYYY-MM
                    processed_line = processed_line.replace(date, formatted)
        # Names - only normalize if they don't already look like standard names
        names_dict = extract_and_normalize_names(processed_line)
        for original, normalized in names_dict.items():
            processed_line = processed_line.replace(original, normalized)
        # Preserve emails/phones (already skipped in spelling)
        emails = EMAIL_PATTERN.findall(line)
        phones = PHONE_PATTERN.findall(line)
        if emails:
            logger.info(f"Found emails: {emails}")
        if phones:
            logger.info(f"Found phone numbers: {phones}")
        processed_lines.append(processed_line)
    corrected_text = '\n'.join(processed_lines)
    # Advanced AI/external only if unstructured AND significantly different from original
    if (ocr_results and
            is_unstructured(corrected_text, ocr_results) and
            calculate_shannon_entropy(corrected_text) > calculate_shannon_entropy(original_text) + 0.5):
        # Validate places (extract potential places heuristically)
        potential_places = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?\s*(?:Street|St|Ave|Road|Rd|City|Town)\b', corrected_text)
        for place in set(potential_places):
            corrected_place, conf = validate_place_name(place)
            if conf > MIN_CONFIDENCE_FOR_CORRECTION:
                corrected_text = corrected_text.replace(place, corrected_place)
        # LLM for overall verification/correction - only if really needed
        corrected_text = call_llm_for_correction(corrected_text)
    return corrected_text
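# End-to-end example (illustrative; no OCR results supplied, so the AI branch is skipped):
#   post_process_text("Meeting on 12/25/2023 with Dr. Jane Smith")
#     -> "Meeting on 2023-12-25 with Jane Smith"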
def extract_text_from_image(file_bytes: bytes) -> str:
    reader = get_ocr_reader()
    if reader is None:
        return ""
    try:
        image = Image.open(io.BytesIO(file_bytes)).convert('RGB')
        image_np = np.array(image)
        results = reader.readtext(image_np, paragraph=True, detail=1)
        logger.info(f"OCR found {len(results)} text segments")
        text_parts = []
        for result in results:
            if len(result) == 3:
                bbox, text, confidence = result
            elif len(result) == 2:
                bbox, text = result
                confidence = 1.0
            else:
                continue
            if confidence > OCR_CONF_THRESHOLD:
                text_parts.append(text)
        raw_text = "\n".join(text_parts).strip()
        return post_process_text(raw_text, results)
    except Exception as e:
        logger.error(f"Image OCR failed: {e}")
        return ""
def extract_text_from_pdf(file_bytes: bytes) -> List[str]:
    text_parts = []
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            logger.info(f"Opened PDF with {len(pdf.pages)} pages")
            for i, page in enumerate(pdf.pages):
                try:
                    page_text = page.extract_text() or ""
                    if len(page_text.strip()) < 50:
                        # Little embedded text; fall back to OCR on a page render
                        reader = get_ocr_reader()
                        if reader:
                            im = page.to_image(resolution=200).original.convert("RGB")
                            image_np = np.array(im)
                            results = reader.readtext(image_np, paragraph=True, detail=1)
                            ocr_text_parts = []
                            for result in results:
                                if len(result) == 3:
                                    bbox, text, confidence = result
                                elif len(result) == 2:
                                    bbox, text = result
                                    confidence = 1.0
                                else:
                                    continue
                                if confidence > OCR_CONF_THRESHOLD:
                                    ocr_text_parts.append(text)
                            raw_ocr_text = "\n".join(ocr_text_parts).strip()
                            processed = post_process_text(raw_ocr_text, results)
                            text_parts.append(processed if processed else page_text)
                        else:
                            text_parts.append(page_text)
                    else:
                        processed = post_process_text(page_text)
                        text_parts.append(processed)
                except Exception as e:
                    logger.error(f"Page {i+1} extraction failed: {e}")
                    text_parts.append("")
        return text_parts
    except Exception as e:
        logger.error(f"PDF extraction failed: {e}")
        return []
def guess_and_extract(filename: str, file_bytes: bytes) -> List[str]:
    ext = ("." + filename.lower().rsplit(".", 1)[-1]) if "." in filename else ""
    try:
        if ext in PDF_EXTS:
            return extract_text_from_pdf(file_bytes)
        elif ext in IMAGE_EXTS:
            text = extract_text_from_image(file_bytes)
            return [text] if text else []
        else:
            # Note: latin-1 decodes any byte sequence, so it acts as the effective fallback
            for encoding in ["utf-8", "latin-1", "iso-8859-1"]:
                try:
                    raw_text = file_bytes.decode(encoding).strip()
                    processed = post_process_text(raw_text)
                    return [processed]
                except UnicodeDecodeError:
                    continue
            return []
    except Exception as e:
        logger.error(f"Extraction failed for {filename}: {e}")
        return []
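# Minimal CLI sketch (an assumption, not part of the original interface):
# run `python <this_module>.py <file>` to print the extracted pages of a local file.
if __name__ == "__main__":
    import sys
    if len(sys.argv) != 2:
        print("usage: python <module>.py <file>")
        sys.exit(1)
    path = sys.argv[1]
    with open(path, "rb") as f:
        data = f.read()
    # Route by extension exactly as the upload handler would
    for page_num, page_text in enumerate(guess_and_extract(os.path.basename(path), data), start=1):
        print(f"--- Page {page_num} ---")
        print(page_text)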