Receipt Processing Agent
Ensemble classification, OCR, field extraction, and anomaly detection
Built by Emily, John, Luke, Michael and Raghu
""" Receipt Processing Pipeline - Hugging Face Spaces App Ensemble classification, OCR, field extraction, anomaly detection, and agentic routing. """ import os import torch import torch.nn as nn import numpy as np import gradio as gr import gradio.routes as gr_routes import easyocr import json import re from PIL import Image, ImageDraw from datetime import datetime from torchvision import transforms, models from transformers import ( ViTForImageClassification, ViTImageProcessor, LayoutLMv3ForTokenClassification, LayoutLMv3Processor, ) from sklearn.ensemble import IsolationForest import warnings warnings.filterwarnings('ignore') # --------------------------------------------------------------------------- # Work around Gradio json_schema traversal crash: # - guard bool schema entries # --------------------------------------------------------------------------- import gradio_client.utils as grc_utils _orig_get_type = grc_utils.get_type _orig_json_schema_to_python_type = grc_utils.json_schema_to_python_type def _safe_get_type(schema): if isinstance(schema, bool): return "any" return _orig_get_type(schema) def _safe_json_schema_to_python_type(schema, defs=None): if isinstance(schema, bool): return "any" try: return _orig_json_schema_to_python_type(schema, defs) except Exception: return "any" grc_utils.get_type = _safe_get_type grc_utils.json_schema_to_python_type = _safe_json_schema_to_python_type # --------------------------------------------------------------------------- # JSON sanitation helper (convert numpy types & PIL-friendly outputs) # --------------------------------------------------------------------------- def to_jsonable(obj): if isinstance(obj, dict): return {k: to_jsonable(v) for k, v in obj.items()} if isinstance(obj, (list, tuple)): return [to_jsonable(v) for v in obj] if isinstance(obj, (np.bool_, bool)): return bool(obj) if isinstance(obj, (np.integer,)): return int(obj) if isinstance(obj, (np.floating,)): return float(obj) if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, Image.Image): return None # avoid serializing images; skip in JSON return obj # --------------------------------------------------------------------------- # Feedback persistence helper (CSV; optionally include section label) # --------------------------------------------------------------------------- def save_feedback(assessment, notes, results_json_str, section="overall"): try: parsed = json.loads(results_json_str) if results_json_str else {} except Exception: parsed = {"raw": results_json_str} entry = { "timestamp": datetime.utcnow().isoformat(), "section": section or "", "assessment": assessment or "", "notes": notes or "", "results": parsed, } import csv fieldnames = ["timestamp", "section", "assessment", "notes", "results"] file_exists = os.path.exists("feedback_logs.csv") with open("feedback_logs.csv", "a", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) if not file_exists: writer.writeheader() writer.writerow({ "timestamp": entry["timestamp"], "section": entry.get("section", ""), "assessment": entry["assessment"], "notes": entry["notes"], "results": json.dumps(entry["results"]), }) return "✅ Feedback saved. (Stored in feedback_logs.csv)" # ============================================================================ # Configuration # ============================================================================ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') MODELS_DIR = 'models' print(f"Device: {DEVICE}") print(f"Models directory: {MODELS_DIR}") # ============================================================================ # Model Classes # ============================================================================ class DocumentClassifier: """ViT-based document classifier (receipt vs other).""" def __init__(self, num_labels=2, model_path=None): self.num_labels = num_labels self.model = None self.processor = None self.model_path = model_path or os.path.join(MODELS_DIR, 'rvl_classifier.pt') self.pretrained = 'WinKawaks/vit-tiny-patch16-224' def load_model(self): try: self.processor = ViTImageProcessor.from_pretrained(self.pretrained) except: self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') self.model = ViTForImageClassification.from_pretrained( self.pretrained, num_labels=self.num_labels, ignore_mismatched_sizes=True ) self.model = self.model.to(DEVICE) self.model.eval() return self.model def load_weights(self, path): if os.path.exists(path): checkpoint = torch.load(path, map_location=DEVICE) if isinstance(checkpoint, dict): if 'model_state_dict' in checkpoint: self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) elif 'state_dict' in checkpoint: self.model.load_state_dict(checkpoint['state_dict'], strict=False) else: self.model.load_state_dict(checkpoint, strict=False) else: self.model.load_state_dict(checkpoint, strict=False) print(f" Loaded ViT weights from {path}") def predict(self, image): if self.model is None: self.load_model() self.model.eval() if not isinstance(image, Image.Image): image = Image.fromarray(image) image = image.convert('RGB') inputs = self.processor(images=image, return_tensors="pt") inputs = {k: v.to(DEVICE) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) probs = torch.softmax(outputs.logits, dim=-1) pred = torch.argmax(probs, dim=-1).item() conf = probs[0, pred].item() is_receipt = pred == 1 label = "receipt" if is_receipt else "other" return { 'is_receipt': is_receipt, 'confidence': conf, 'label': label, 'probabilities': probs[0].cpu().numpy().tolist() } class ResNetDocumentClassifier: """ResNet18-based document classifier.""" def __init__(self, num_labels=2, model_path=None): self.num_labels = num_labels self.model = None self.model_path = model_path or os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt') self.use_class_mapping = False self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def load_model(self): self.model = models.resnet18(weights=None) self.model = self.model.to(DEVICE) self.model.eval() return self.model def load_weights(self, path): if not os.path.exists(path): return checkpoint = torch.load(path, map_location=DEVICE) if isinstance(checkpoint, dict): state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint)) id2label = checkpoint.get('id2label', None) else: state_dict = checkpoint id2label = None # Determine number of classes from checkpoint fc_weight_key = 'fc.weight' if fc_weight_key in state_dict: num_classes = state_dict[fc_weight_key].shape[0] else: num_classes = self.num_labels # Rebuild final layer if needed if num_classes != self.model.fc.out_features: self.model.fc = nn.Linear(self.model.fc.in_features, num_classes) self.model = self.model.to(DEVICE) self.model.load_state_dict(state_dict, strict=False) # Handle 16-class RVL-CDIP models if num_classes == 16: self.use_class_mapping = True self.receipt_class_idx = 11 # Receipt class in RVL-CDIP print(f" Loaded ResNet weights from {path} ({num_classes} classes)") def predict(self, image): if self.model is None: self.load_model() self.model.eval() if not isinstance(image, Image.Image): image = Image.fromarray(image) image = image.convert('RGB') input_tensor = self.transform(image).unsqueeze(0).to(DEVICE) with torch.no_grad(): outputs = self.model(input_tensor) probs = torch.softmax(outputs, dim=-1) if self.use_class_mapping: receipt_prob = probs[0, self.receipt_class_idx].item() other_prob = 1.0 - receipt_prob is_receipt = receipt_prob > 0.5 conf = receipt_prob if is_receipt else other_prob final_probs = [other_prob, receipt_prob] else: pred = torch.argmax(probs, dim=-1).item() conf = probs[0, pred].item() is_receipt = pred == 1 final_probs = probs[0].cpu().numpy().tolist() return { 'is_receipt': is_receipt, 'confidence': conf, 'label': "receipt" if is_receipt else "other", 'probabilities': final_probs } class EnsembleDocumentClassifier: """Ensemble of ViT and ResNet classifiers.""" def __init__(self, model_configs=None, weights=None): self.model_configs = model_configs or [ {'name': 'vit_base', 'path': os.path.join(MODELS_DIR, 'rvl_classifier.pt')}, {'name': 'resnet18', 'path': os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')}, ] # Filter to existing models self.model_configs = [cfg for cfg in self.model_configs if os.path.exists(cfg['path'])] if not self.model_configs: print("Warning: No model files found, will use default ViT") self.model_configs = [{'name': 'vit_default', 'path': None}] self.weights = weights or [1.0 / len(self.model_configs)] * len(self.model_configs) self.classifiers = [] self.processor = None def load_models(self): print(f"Loading ensemble with {len(self.model_configs)} models...") for cfg in self.model_configs: is_resnet = 'resnet' in cfg['name'].lower() or 'resnet' in cfg.get('path', '').lower() if is_resnet: classifier = ResNetDocumentClassifier(num_labels=2, model_path=cfg['path']) else: classifier = DocumentClassifier(num_labels=2, model_path=cfg['path']) classifier.load_model() if cfg['path'] and os.path.exists(cfg['path']): try: classifier.load_weights(cfg['path']) except Exception as e: print(f" Warning: Could not load {cfg['name']}: {e}") self.classifiers.append(classifier) if self.processor is None: if hasattr(classifier, 'processor'): self.processor = classifier.processor elif hasattr(classifier, 'transform'): self.processor = classifier.transform print(f"Ensemble ready with {len(self.classifiers)} models") return self def predict(self, image, return_individual=False): if not self.classifiers: self.load_models() all_probs = [] individual_results = [] for i, classifier in enumerate(self.classifiers): result = classifier.predict(image) probs = result.get('probabilities', [0.5, 0.5]) if len(probs) < 2: probs = [1 - result['confidence'], result['confidence']] all_probs.append(probs) individual_results.append({ 'name': self.model_configs[i]['name'], 'prediction': result['label'], 'confidence': result['confidence'], 'probabilities': probs }) # Weighted average ensemble_probs = np.zeros(2) for i, probs in enumerate(all_probs): ensemble_probs += np.array(probs[:2]) * self.weights[i] pred = np.argmax(ensemble_probs) is_receipt = pred == 1 conf = ensemble_probs[pred] result = { 'is_receipt': is_receipt, 'confidence': float(conf), 'label': "receipt" if is_receipt else "other", 'probabilities': ensemble_probs.tolist() } if return_individual: result['individual_results'] = individual_results return result # ============================================================================ # OCR # ============================================================================ class ReceiptOCR: """Enhanced OCR with EasyOCR + TrOCR + PaddleOCR + Tesseract ensemble.""" def __init__(self): self.reader = None self.trocr_engine = None self.paddleocr_engine = None self.use_tesseract = False # Engine weights for ensemble self.engine_weights = { 'trocr': 0.40, # Highest weight - best quality 'easyocr': 0.35, 'paddleocr': 0.30, 'tesseract': 0.20 } # Try to initialize TrOCR try: from transformers import TrOCRProcessor, VisionEncoderDecoderModel self.trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") self.trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed") self.trocr_model = self.trocr_model.to(DEVICE) self.trocr_model.eval() self.trocr_available = True print("TrOCR initialized") except Exception as e: self.trocr_available = False print(f"TrOCR not available: {e}") # Try to initialize PaddleOCR try: from paddleocr import PaddleOCR self.paddleocr_engine = PaddleOCR(use_angle_cls=True, lang='en', show_log=False) self.paddleocr_available = True print("PaddleOCR initialized") except Exception as e: self.paddleocr_available = False print(f"PaddleOCR not available: {e}") # Try to initialize Tesseract try: import pytesseract self.use_tesseract = True except ImportError: self.use_tesseract = False def load(self): if self.reader is None: print("Loading EasyOCR...") self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available()) print("EasyOCR ready") return self def _preprocess_image(self, image, method='enhance'): """Apply image preprocessing to improve OCR accuracy.""" import cv2 if isinstance(image, Image.Image): img_array = np.array(image) else: img_array = image.copy() if method == 'enhance': # Convert to grayscale if needed if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) enhanced = clahe.apply(gray) # Denoise denoised = cv2.fastNlMeansDenoising(enhanced, h=10) # Convert back to RGB for OCR engines return cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB) elif method == 'sharpen': # Sharpen the image kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) if len(img_array.shape) == 3: sharpened = cv2.filter2D(img_array, -1, kernel) else: gray = img_array sharpened = cv2.filter2D(gray, -1, kernel) sharpened = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB) return sharpened return img_array def _run_easyocr(self, image): """Run EasyOCR.""" if self.reader is None: self.load() results = self.reader.readtext(image) extracted = [] for bbox, text, conf in results: x_coords = [p[0] for p in bbox] y_coords = [p[1] for p in bbox] extracted.append({ 'text': text.strip(), 'confidence': conf, 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)], 'engine': 'easyocr' }) return extracted def _run_trocr(self, image, boxes): """Run TrOCR on detected text regions.""" if not self.trocr_available: return [] if isinstance(image, np.ndarray): pil_image = Image.fromarray(image).convert('RGB') else: pil_image = image.convert('RGB') results = [] for box in boxes: try: if isinstance(box, list) and len(box) >= 4: # Convert to [x1, y1, x2, y2] if isinstance(box[0], list): x1 = int(min(p[0] for p in box)) y1 = int(min(p[1] for p in box)) x2 = int(max(p[0] for p in box)) y2 = int(max(p[1] for p in box)) else: x1, y1, x2, y2 = [int(b) for b in box[:4]] # Crop and recognize cropped = pil_image.crop((x1, y1, x2, y2)) # TrOCR recognition pixel_values = self.trocr_processor(images=cropped, return_tensors="pt").pixel_values.to(DEVICE) with torch.no_grad(): generated_ids = self.trocr_model.generate( pixel_values, max_length=128, num_beams=4, early_stopping=True ) text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] if text.strip(): results.append({ 'text': text.strip(), 'confidence': 0.9, # TrOCR doesn't provide confidence, use high default 'bbox': [x1, y1, x2, y2], 'engine': 'trocr' }) except Exception as e: continue return results def _run_paddleocr(self, image): """Run PaddleOCR.""" if not self.paddleocr_available: return [] try: result = self.paddleocr_engine.ocr(image, cls=True) if result is None or len(result) == 0 or result[0] is None: return [] extracted = [] for line in result[0]: if line is None: continue bbox, (text, conf) = line x_coords = [p[0] for p in bbox] y_coords = [p[1] for p in bbox] extracted.append({ 'text': text.strip(), 'confidence': conf, 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)], 'engine': 'paddleocr' }) return extracted except Exception as e: print(f"PaddleOCR error: {e}") return [] def _run_tesseract(self, image): """Run Tesseract OCR.""" if not self.use_tesseract: return [] try: import pytesseract if isinstance(image, Image.Image): pil_image = image.convert('RGB') else: pil_image = Image.fromarray(image).convert('RGB') data = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT) results = [] n_boxes = len(data['text']) for i in range(n_boxes): text = data['text'][i].strip() conf = int(data['conf'][i]) if text and conf > 0: x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i] results.append({ 'text': text, 'confidence': conf / 100.0, 'bbox': [x, y, x+w, y+h], 'engine': 'tesseract' }) return results except Exception as e: print(f"Tesseract OCR error: {e}") return [] def _compute_iou(self, box1, box2): """Compute Intersection over Union for bounding boxes.""" x1_1, y1_1, x2_1, y2_1 = box1 x1_2, y1_2, x2_2, y2_2 = box2 xi1 = max(x1_1, x1_2) yi1 = max(y1_1, y1_2) xi2 = min(x2_1, x2_2) yi2 = min(y2_1, y2_2) inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1) box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) union_area = box1_area + box2_area - inter_area return inter_area / union_area if union_area > 0 else 0 def _merge_results(self, all_results): """Merge results from multiple OCR engines using weighted voting.""" if not all_results: return [] # Use the engine with most detections as base base_engine = max(all_results.keys(), key=lambda k: len(all_results[k])) base_results = all_results[base_engine] merged = [] for base_result in base_results: base_box = base_result['bbox'] base_text = base_result['text'] base_conf = base_result['confidence'] # Find matching results from other engines matches = [(base_text, base_conf, self.engine_weights.get(base_engine, 0.3))] for engine_name, results in all_results.items(): if engine_name == base_engine: continue for result in results: iou = self._compute_iou(base_box, result['bbox']) if iou > 0.3: # Same text region weight = self.engine_weights.get(engine_name, 0.2) matches.append((result['text'], result['confidence'], weight)) # Vote on the best text if len(matches) == 1: final_text = base_text final_conf = base_conf else: # Weighted voting text_scores = {} for text, conf, weight in matches: if text not in text_scores: text_scores[text] = 0 text_scores[text] += conf * weight final_text = max(text_scores.keys(), key=lambda t: text_scores[t]) total_weight = sum(w for _, _, w in matches) final_conf = min(0.99, text_scores[final_text] / total_weight if total_weight > 0 else 0.5) merged.append({ 'text': final_text, 'confidence': final_conf, 'bbox': base_box, 'engines_used': len(matches) }) return merged def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=False): """Extract text with positions using ensemble of OCR engines.""" if isinstance(image, Image.Image): img_array = np.array(image) else: img_array = image.copy() all_results = {} # Run EasyOCR (always available) try: easyocr_results = self._run_easyocr(img_array) if easyocr_results: all_results['easyocr'] = easyocr_results except Exception as e: print(f"EasyOCR error: {e}") # Run PaddleOCR if available if self.paddleocr_available and use_ensemble: try: paddleocr_results = self._run_paddleocr(img_array) if paddleocr_results: all_results['paddleocr'] = paddleocr_results except Exception as e: print(f"PaddleOCR error: {e}") # Run Tesseract if available if self.use_tesseract and use_ensemble: try: tesseract_results = self._run_tesseract(img_array) if tesseract_results: all_results['tesseract'] = tesseract_results except Exception as e: print(f"Tesseract error: {e}") # Run TrOCR on detected boxes (needs boxes from other engines) if self.trocr_available and use_ensemble and all_results: try: # Get boxes from best available engine source_engine = max(all_results.keys(), key=lambda k: len(all_results[k])) boxes = [r['bbox'] for r in all_results[source_engine]] trocr_results = self._run_trocr(img_array, boxes) if trocr_results: all_results['trocr'] = trocr_results except Exception as e: print(f"TrOCR error: {e}") # Merge results if ensemble, otherwise use EasyOCR only if use_ensemble and len(all_results) > 1: merged = self._merge_results(all_results) elif 'easyocr' in all_results: merged = all_results['easyocr'] else: merged = [] # Filter by confidence filtered = [r for r in merged if r['confidence'] >= min_confidence] # If results are poor, try with preprocessing avg_confidence = np.mean([r['confidence'] for r in filtered]) if filtered else 0 if len(filtered) < 3 or avg_confidence < 0.4: try: preprocessed = self._preprocess_image(image, method='enhance') retry_results = self._run_easyocr(preprocessed) retry_filtered = [r for r in retry_results if r['confidence'] >= min_confidence] retry_avg = np.mean([r['confidence'] for r in retry_filtered]) if retry_filtered else 0 if retry_avg > avg_confidence: filtered = retry_filtered except Exception: pass # Sort by confidence (highest first) filtered.sort(key=lambda x: x['confidence'], reverse=True) return filtered def postprocess_receipt(self, ocr_results): """Extract structured fields from OCR results with improved patterns.""" # Fix common OCR errors (S->$ in amounts) fixed_results = [] for r in ocr_results: fixed_r = r.copy() fixed_r['text'] = self._fix_ocr_text(r['text']) fixed_results.append(fixed_r) full_text = ' '.join([r['text'] for r in fixed_results]) fields = { 'vendor': self._extract_vendor(ocr_results), 'date': self._extract_date(full_text), 'total': self._extract_total(full_text), 'time': self._extract_time(full_text) } return fields def _extract_vendor(self, ocr_results): """Extract vendor name - look for business name in top portion of receipt.""" if not ocr_results: return None # Sort by vertical position (top of receipt first) sorted_results = sorted(ocr_results, key=lambda x: x['bbox'][1] if isinstance(x['bbox'], list) and len(x['bbox']) > 1 else 0) # Look in top 10 results for vendor name top_results = sorted_results[:10] # Skip words that are clearly not vendor names skip_words = {'TOTAL', 'DATE', 'TIME', 'RECEIPT', 'THANK', 'YOU', 'STORE', 'HOST', 'ORDER', 'TYPE', 'TOGO', 'DINE', 'IN', 'CHECK', 'CLOSED', 'AMEX', 'VISA', 'MASTERCARD', 'CASH', 'CHANGE', 'SUBTOTAL', 'TAX'} # Known vendor patterns (common stores) known_vendors = ['EINSTEIN', 'STARBUCKS', 'MCDONALDS', 'WALMART', 'TARGET', 'CHIPOTLE', 'PANERA', 'DUNKIN', 'SUBWAY', 'CHICK-FIL-A'] # First, check if any known vendor is in the OCR results for result in top_results: text = result['text'].strip().upper() for vendor in known_vendors: if vendor in text: return result['text'].strip() # Look for longest meaningful text (likely the business name) candidates = [] for result in top_results: text = result['text'].strip() text_upper = text.upper() # Skip short texts, numbers, and common skip words if len(text) < 3: continue if text_upper in skip_words: continue if re.match(r'^[\d\s\-\/\.\$\,]+$', text): # Skip pure numbers/symbols continue if re.match(r'^#?\d+$', text): # Skip store numbers like #2846 continue # Prefer texts with letters and reasonable length if len(text) >= 4 and any(c.isalpha() for c in text): candidates.append((text, len(text), result['confidence'])) # Return the longest candidate with good confidence if candidates: # Sort by length (longer = more likely to be full vendor name) candidates.sort(key=lambda x: (x[1], x[2]), reverse=True) return candidates[0][0] return None def _extract_date(self, text): """Extract date with improved patterns.""" patterns = [ r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', # MM/DD/YYYY or MM-DD-YYYY r'\b\d{4}[/-]\d{2}[/-]\d{2}\b', # YYYY-MM-DD r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{4}\b', # Month DD, YYYY ] for pattern in patterns: matches = re.findall(pattern, text, re.IGNORECASE) if matches: return matches[0] return None def _extract_total(self, text): """Extract total amount - handles S/$ OCR confusion.""" # Fix S -> $ in amounts (common OCR error) fixed_text = re.sub(r'\bS(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)\b', r'$\1', text) # Find all dollar amounts (now with fixed $ symbols) all_amounts = re.findall(r'[\$S](\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', fixed_text) all_amounts = [float(a.replace(',', '')) for a in all_amounts if a] if not all_amounts: # Try finding any decimal amounts all_amounts = re.findall(r'(\d{1,3}(?:,\d{3})*\.\d{2})', fixed_text) all_amounts = [float(a.replace(',', '')) for a in all_amounts if a] if not all_amounts: return None # Look for "TOTAL", "AMOUNT DUE", "BALANCE" keywords and find amount near them lines = fixed_text.split('\n') for i, line in enumerate(lines): line_upper = line.upper() if any(keyword in line_upper for keyword in ['TOTAL', 'AMOUNT DUE', 'BALANCE DUE', 'DUE']): # Check this line and next 2 lines for amount search_text = ' '.join(lines[i:min(i+3, len(lines))]) # Match both $ and S followed by amounts matches = re.findall(r'[\$S](\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', search_text) if matches: amounts_near_total = [float(m.replace(',', '')) for m in matches] return f"{max(amounts_near_total):.2f}" # Fallback: return largest amount overall return f"{max(all_amounts):.2f}" def _extract_time(self, text): """Extract time.""" patterns = [ r'\b(\d{1,2}):(\d{2})\s*(?:AM|PM)\b', r'\b(\d{1,2}):(\d{2})\b', ] for pattern in patterns: match = re.search(pattern, text, re.IGNORECASE) if match: return match.group(0) return None def _fix_ocr_text(self, text): """Fix common OCR errors like S->$ in amounts.""" # Fix S followed by digits -> $ (e.g., S154.06 -> $154.06) text = re.sub(r'\bS(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)\b', r'$\1', text) # Fix Subtolal -> Subtotal (common OCR error) text = re.sub(r'\bSubtolal\b', 'Subtotal', text, flags=re.IGNORECASE) return text class LayoutLMFieldExtractor: """LayoutLMv3-based field extractor using fine-tuned weights if available.""" def __init__(self, model_path=None): self.model_path = model_path or os.path.join(MODELS_DIR, 'layoutlm_extractor.pt') self.id2label = { 0: 'O', 1: 'B-VENDOR', 2: 'I-VENDOR', 3: 'B-DATE', 4: 'I-DATE', 5: 'B-TOTAL', 6: 'I-TOTAL', 7: 'B-TIME', 8: 'I-TIME' } self.label2id = {v: k for k, v in self.id2label.items()} self.processor = None self.model = None def load(self): print("Loading LayoutLMv3 extractor...") self.processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base") self.model = LayoutLMv3ForTokenClassification.from_pretrained( "microsoft/layoutlmv3-base", num_labels=len(self.id2label), id2label=self.id2label, label2id=self.label2id, ) if os.path.exists(self.model_path): checkpoint = torch.load(self.model_path, map_location=DEVICE) if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: checkpoint = checkpoint['model_state_dict'] if isinstance(checkpoint, dict): missing, unexpected = self.model.load_state_dict(checkpoint, strict=False) print(f"Loaded LayoutLM weights; missing={len(missing)}, unexpected={len(unexpected)}") self.model = self.model.to(DEVICE) self.model.eval() print("LayoutLMv3 ready") return self def _prepare_boxes(self, ocr_results, image_size): """Convert absolute pixel boxes to LayoutLM 0-1000 format.""" width, height = image_size boxes = [] words = [] for r in ocr_results: bbox = r.get("bbox", [0, 0, width, height]) x0, y0, x1, y1 = bbox boxes.append([ int(1000 * x0 / width), int(1000 * y0 / height), int(1000 * x1 / width), int(1000 * y1 / height), ]) words.append(r.get("text", "")) return words, boxes def predict_fields(self, image, ocr_results=None): """Predict fields with confidence scores and improved total extraction.""" if self.model is None: self.load() if not isinstance(image, Image.Image): image = Image.fromarray(image) image = image.convert("RGB") if ocr_results: words, boxes = self._prepare_boxes(ocr_results, image.size) encoding = self.processor( image, words=words, boxes=boxes, return_tensors="pt", truncation=True, padding="max_length", max_length=512, ) else: encoding = self.processor(image, return_tensors="pt") encoding = {k: v.to(DEVICE) for k, v in encoding.items()} with torch.no_grad(): outputs = self.model(**encoding) logits = outputs.logits[0] # Get softmax probabilities for confidence probs = torch.softmax(logits, dim=-1) preds = logits.argmax(-1).cpu().tolist() probs_np = probs.cpu().numpy() tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu()) # Extract entities with confidence scores entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []} entity_confidences = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []} entity_positions = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []} current = {"label": None, "tokens": [], "start_idx": None} for idx, (token, pred) in enumerate(zip(tokens, preds)): label = self.id2label.get(pred, "O") conf = float(probs_np[idx, pred]) if token in ["[PAD]", "[CLS]", "[SEP]"]: continue if label.startswith("B-"): # Flush previous if current["label"] and current["tokens"]: entity_text = " ".join(current["tokens"]).replace("▁", " ").strip() entities[current["label"]].append(entity_text) entity_confidences[current["label"]].append(conf) entity_positions[current["label"]].append(current["start_idx"]) current = {"label": label[2:], "tokens": [token], "start_idx": idx} elif label.startswith("I-") and current["label"] == label[2:]: current["tokens"].append(token) else: if current["label"] and current["tokens"]: entity_text = " ".join(current["tokens"]).replace("▁", " ").strip() entities[current["label"]].append(entity_text) entity_confidences[current["label"]].append(conf) entity_positions[current["label"]].append(current["start_idx"]) current = {"label": None, "tokens": [], "start_idx": None} if current["label"] and current["tokens"]: entity_text = " ".join(current["tokens"]).replace("▁", " ").strip() entities[current["label"]].append(entity_text) entity_confidences[current["label"]].append(conf) entity_positions[current["label"]].append(current["start_idx"]) # Smart field selection with confidence and position awareness result = {} # Vendor: prefer first high-confidence result if entities["VENDOR"]: best_vendor_idx = max(range(len(entities["VENDOR"])), key=lambda i: entity_confidences["VENDOR"][i]) if entity_confidences["VENDOR"][best_vendor_idx] > 0.3: result["vendor"] = entities["VENDOR"][best_vendor_idx] # Date: prefer first high-confidence result if entities["DATE"]: best_date_idx = max(range(len(entities["DATE"])), key=lambda i: entity_confidences["DATE"][i]) if entity_confidences["DATE"][best_date_idx] > 0.3: result["date"] = entities["DATE"][best_date_idx] # Time: prefer first high-confidence result if entities["TIME"]: best_time_idx = max(range(len(entities["TIME"])), key=lambda i: entity_confidences["TIME"][i]) if entity_confidences["TIME"][best_time_idx] > 0.3: result["time"] = entities["TIME"][best_time_idx] # Total: improved extraction - look for amounts near "TOTAL" keyword in OCR if entities["TOTAL"]: # Get all total candidates with confidence total_candidates = [(entities["TOTAL"][i], entity_confidences["TOTAL"][i], entity_positions["TOTAL"][i]) for i in range(len(entities["TOTAL"]))] # If OCR results available, validate against OCR text if ocr_results: ocr_text = ' '.join([r['text'] for r in ocr_results]).upper() ocr_lines = [r['text'] for r in ocr_results] # Find amounts near "TOTAL" keyword best_total = None best_conf = 0 for total_val, conf, pos in total_candidates: # Clean the total value total_clean = str(total_val).replace('$', '').replace(',', '').replace('.', '').strip() # Check if this total appears near "TOTAL" keyword in OCR for i, line in enumerate(ocr_lines): line_upper = line.upper() if 'TOTAL' in line_upper or 'AMOUNT DUE' in line_upper: # Check this line and next 2 lines for the amount search_text = ' '.join(ocr_lines[i:min(i+3, len(ocr_lines))]) search_clean = search_text.replace('$', '').replace(',', '').replace('.', '') if total_clean in search_clean: # Found near TOTAL keyword - high confidence if conf > best_conf: best_total = total_val best_conf = conf break if best_total: result["total"] = best_total else: # Fallback: use highest confidence total best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1]) if total_candidates[best_idx][1] > 0.3: result["total"] = total_candidates[best_idx][0] else: # No OCR, use highest confidence best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1]) if total_candidates[best_idx][1] > 0.3: result["total"] = total_candidates[best_idx][0] return result # ============================================================================ # Anomaly Detection # ============================================================================ class AnomalyDetector: """Isolation Forest-based anomaly detection.""" def __init__(self): self.model = IsolationForest(contamination=0.1, random_state=42) self.is_fitted = False def extract_features(self, fields): """Extract features from receipt fields.""" total = 0 try: total = float(fields.get('total', 0) or 0) except: pass vendor = fields.get('vendor', '') or '' date = fields.get('date', '') or '' features = [ total, np.log1p(total), len(vendor), 1 if date else 0, 1, # num_items placeholder 12, # hour placeholder total, # amount_per_item placeholder 0 # is_weekend placeholder ] return np.array(features).reshape(1, -1) def predict(self, fields): features = self.extract_features(fields) # Simple rule-based detection if model not fitted reasons = [] total = float(fields.get('total', 0) or 0) if total > 1000: reasons.append(f"High amount: ${total:.2f}") if not fields.get('vendor'): reasons.append("Missing vendor") if not fields.get('date'): reasons.append("Missing date") is_anomaly = len(reasons) > 0 return { 'is_anomaly': is_anomaly, 'score': -0.5 if is_anomaly else 0.5, 'prediction': 'ANOMALY' if is_anomaly else 'NORMAL', 'reasons': reasons } # ============================================================================ # Initialize Models # ============================================================================ print("\n" + "="*50) print("Initializing models...") print("="*50) # Check for model files model_files = [] if os.path.exists(MODELS_DIR): model_files = [f for f in os.listdir(MODELS_DIR) if f.endswith('.pt')] print(f"Found model files: {model_files}") else: print(f"Models directory not found: {MODELS_DIR}") os.makedirs(MODELS_DIR, exist_ok=True) # Initialize components try: ensemble_classifier = EnsembleDocumentClassifier() ensemble_classifier.load_models() except Exception as e: print(f"Warning: Could not load ensemble classifier: {e}") ensemble_classifier = None try: receipt_ocr = ReceiptOCR() receipt_ocr.load() except Exception as e: print(f"Warning: Could not load OCR: {e}") receipt_ocr = None try: layoutlm_extractor = LayoutLMFieldExtractor() layoutlm_extractor.load() except Exception as e: print(f"Warning: Could not load LayoutLMv3 extractor: {e}") layoutlm_extractor = None anomaly_detector = AnomalyDetector() print("\n" + "="*50) print("Initialization complete!") print("="*50 + "\n") # ============================================================================ # Helper Functions # ============================================================================ def draw_ocr_boxes(image, ocr_results): """Draw bounding boxes on image.""" img_copy = image.copy() draw = ImageDraw.Draw(img_copy) for r in ocr_results: conf = r.get('confidence', 0.5) bbox = r.get('bbox', []) if conf > 0.8: color = '#28a745' # Green elif conf > 0.5: color = '#ffc107' # Yellow else: color = '#dc3545' # Red if len(bbox) >= 4: draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], outline=color, width=2) return img_copy def process_receipt(image): """Main processing function for Gradio.""" if image is None: return ( "
Ensemble classification, OCR, field extraction, and anomaly detection
Built by Emily, John, Luke, Michael and Raghu