Raghu commited on
Commit
23980e2
·
1 Parent(s): 53ff1f6

Enhance OCR: add Tesseract fallback, better preprocessing, improved retry logic

Browse files
Files changed (2) hide show
  1. app.py +247 -34
  2. requirements.txt +1 -0
app.py CHANGED
@@ -369,10 +369,16 @@ class EnsembleDocumentClassifier:
369
  # ============================================================================
370
 
371
  class ReceiptOCR:
372
- """EasyOCR wrapper with retry logic."""
373
 
374
  def __init__(self):
375
  self.reader = None
 
 
 
 
 
 
376
 
377
  def load(self):
378
  if self.reader is None:
@@ -381,14 +387,162 @@ class ReceiptOCR:
381
  print("EasyOCR ready")
382
  return self
383
 
384
- def extract_with_positions(self, image, min_confidence=0.3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  if self.reader is None:
386
  self.load()
387
 
 
388
  if isinstance(image, Image.Image):
389
  image = np.array(image)
390
 
391
- results = self.reader.readtext(image)
 
 
 
 
 
392
 
393
  extracted = []
394
  for bbox, text, conf in results:
@@ -396,15 +550,59 @@ class ReceiptOCR:
396
  x_coords = [p[0] for p in bbox]
397
  y_coords = [p[1] for p in bbox]
398
  extracted.append({
399
- 'text': text,
400
  'confidence': conf,
401
- 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
 
402
  })
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  return extracted
405
 
406
  def postprocess_receipt(self, ocr_results):
407
- """Extract structured fields from OCR results."""
408
  full_text = ' '.join([r['text'] for r in ocr_results])
409
 
410
  fields = {
@@ -417,49 +615,64 @@ class ReceiptOCR:
417
  return fields
418
 
419
  def _extract_vendor(self, ocr_results):
420
- if ocr_results:
421
- # Usually first line is vendor
422
- return ocr_results[0]['text']
423
- return None
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
  def _extract_date(self, text):
 
426
  patterns = [
427
- r'\d{1,2}/\d{1,2}/\d{2,4}',
428
- r'\d{1,2}-\d{1,2}-\d{2,4}',
429
- r'\d{4}-\d{2}-\d{2}',
430
  ]
431
  for pattern in patterns:
432
- match = re.search(pattern, text)
433
- if match:
434
- return match.group()
435
  return None
436
 
437
  def _extract_total(self, text):
 
 
438
  patterns = [
439
- r'TOTAL[:\s]*\$?(\d+\.?\d*)',
440
- r'AMOUNT[:\s]*\$?(\d+\.?\d*)',
441
- r'DUE[:\s]*\$?(\d+\.?\d*)',
442
  ]
 
443
  for pattern in patterns:
444
- match = re.search(pattern, text, re.IGNORECASE)
445
- if match:
446
- return match.group(1)
 
 
447
 
448
- # Find largest dollar amount
449
- amounts = re.findall(r'\$(\d+\.\d{2})', text)
450
- if amounts:
451
- return max(amounts, key=float)
452
  return None
453
 
454
  def _extract_time(self, text):
455
- pattern = r'\d{1,2}:\d{2}(?::\d{2})?(?:\s*[AP]M)?'
456
- match = re.search(pattern, text, re.IGNORECASE)
457
- return match.group() if match else None
458
-
459
-
460
- # ============================================================================
461
- # LayoutLMv3 Field Extractor
462
- # ============================================================================
 
 
463
 
464
  class LayoutLMFieldExtractor:
465
  """LayoutLMv3-based field extractor using fine-tuned weights if available."""
 
369
  # ============================================================================
370
 
371
  class ReceiptOCR:
372
+ """Enhanced OCR with EasyOCR + Tesseract fallback, better preprocessing, and retry logic."""
373
 
374
  def __init__(self):
375
  self.reader = None
376
+ self.use_tesseract = False
377
+ try:
378
+ import pytesseract
379
+ self.use_tesseract = True
380
+ except ImportError:
381
+ pass
382
 
383
  def load(self):
384
  if self.reader is None:
 
387
  print("EasyOCR ready")
388
  return self
389
 
390
+ def _preprocess_image(self, image, method='enhance'):
391
+ """Apply image preprocessing to improve OCR accuracy."""
392
+ import cv2
393
+
394
+ if isinstance(image, Image.Image):
395
+ img_array = np.array(image)
396
+ else:
397
+ img_array = image.copy()
398
+
399
+ if method == 'enhance':
400
+ # Convert to grayscale if needed
401
+ if len(img_array.shape) == 3:
402
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
403
+ else:
404
+ gray = img_array
405
+
406
+ # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
407
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
408
+ enhanced = clahe.apply(gray)
409
+
410
+ # Denoise
411
+ denoised = cv2.fastNlMeansDenoising(enhanced, h=10)
412
+
413
+ # Convert back to RGB for EasyOCR
414
+ return cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB)
415
+
416
+ elif method == 'sharpen':
417
+ # Sharpen the image
418
+ kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
419
+ if len(img_array.shape) == 3:
420
+ sharpened = cv2.filter2D(img_array, -1, kernel)
421
+ else:
422
+ gray = img_array
423
+ sharpened = cv2.filter2D(gray, -1, kernel)
424
+ sharpened = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
425
+ return sharpened
426
+
427
+ elif method == 'binarize':
428
+ # Adaptive thresholding
429
+ if len(img_array.shape) == 3:
430
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
431
+ else:
432
+ gray = img_array
433
+ binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
434
+ cv2.THRESH_BINARY, 11, 2)
435
+ return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
436
+
437
+ return img_array
438
+
439
+ def _extract_with_tesseract(self, image):
440
+ """Fallback OCR using Tesseract."""
441
+ if not self.use_tesseract:
442
+ return []
443
+
444
+ try:
445
+ import pytesseract
446
+
447
+ if isinstance(image, Image.Image):
448
+ pil_image = image.convert('RGB')
449
+ else:
450
+ pil_image = Image.fromarray(image).convert('RGB')
451
+
452
+ # Get detailed output with bounding boxes
453
+ data = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
454
+
455
+ results = []
456
+ n_boxes = len(data['text'])
457
+
458
+ for i in range(n_boxes):
459
+ text = data['text'][i].strip()
460
+ conf = int(data['conf'][i])
461
+
462
+ if text and conf > 0:
463
+ x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
464
+ results.append({
465
+ 'text': text,
466
+ 'confidence': conf / 100.0,
467
+ 'bbox': [x, y, x+w, y+h],
468
+ 'engine': 'tesseract'
469
+ })
470
+
471
+ return results
472
+ except Exception as e:
473
+ print(f"Tesseract OCR error: {e}")
474
+ return []
475
+
476
+ def _merge_ocr_results(self, easyocr_results, tesseract_results):
477
+ """Merge results from multiple OCR engines, preferring higher confidence."""
478
+ if not tesseract_results:
479
+ return easyocr_results
480
+
481
+ # Create a map of EasyOCR results by approximate position
482
+ merged = []
483
+ used_tesseract = set()
484
+
485
+ for easy_result in easyocr_results:
486
+ best_match = None
487
+ best_iou = 0
488
+
489
+ # Find best matching Tesseract result
490
+ for i, tess_result in enumerate(tesseract_results):
491
+ if i in used_tesseract:
492
+ continue
493
+
494
+ # Simple IoU calculation
495
+ iou = self._compute_iou(easy_result['bbox'], tess_result['bbox'])
496
+ if iou > best_iou and iou > 0.3: # 30% overlap threshold
497
+ best_iou = iou
498
+ best_match = (i, tess_result)
499
+
500
+ if best_match and best_match[1]['confidence'] > easy_result['confidence']:
501
+ # Use Tesseract result if it's more confident
502
+ merged.append(best_match[1])
503
+ used_tesseract.add(best_match[0])
504
+ else:
505
+ merged.append(easy_result)
506
+
507
+ # Add unused Tesseract results
508
+ for i, tess_result in enumerate(tesseract_results):
509
+ if i not in used_tesseract:
510
+ merged.append(tess_result)
511
+
512
+ return merged
513
+
514
+ def _compute_iou(self, box1, box2):
515
+ """Compute Intersection over Union for bounding boxes."""
516
+ x1_1, y1_1, x2_1, y2_1 = box1
517
+ x1_2, y1_2, x2_2, y2_2 = box2
518
+
519
+ xi1 = max(x1_1, x1_2)
520
+ yi1 = max(y1_1, y1_2)
521
+ xi2 = min(x2_1, x2_2)
522
+ yi2 = min(y2_1, y2_2)
523
+
524
+ inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
525
+ box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
526
+ box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
527
+ union_area = box1_area + box2_area - inter_area
528
+
529
+ return inter_area / union_area if union_area > 0 else 0
530
+
531
+ def extract_with_positions(self, image, min_confidence=0.3, use_fallback=True):
532
+ """Extract text with positions using EasyOCR + optional Tesseract fallback."""
533
  if self.reader is None:
534
  self.load()
535
 
536
+ original_image = image
537
  if isinstance(image, Image.Image):
538
  image = np.array(image)
539
 
540
+ # Try EasyOCR first
541
+ try:
542
+ results = self.reader.readtext(image)
543
+ except Exception as e:
544
+ print(f"EasyOCR error: {e}")
545
+ results = []
546
 
547
  extracted = []
548
  for bbox, text, conf in results:
 
550
  x_coords = [p[0] for p in bbox]
551
  y_coords = [p[1] for p in bbox]
552
  extracted.append({
553
+ 'text': text.strip(),
554
  'confidence': conf,
555
+ 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
556
+ 'engine': 'easyocr'
557
  })
558
 
559
+ # Check if we need fallback (low confidence or few results)
560
+ avg_confidence = np.mean([r['confidence'] for r in extracted]) if extracted else 0
561
+ needs_fallback = use_fallback and (len(extracted) < 3 or avg_confidence < 0.5)
562
+
563
+ if needs_fallback and self.use_tesseract:
564
+ # Try preprocessing + Tesseract
565
+ preprocessed = self._preprocess_image(original_image, method='enhance')
566
+ tesseract_results = self._extract_with_tesseract(preprocessed)
567
+
568
+ if tesseract_results:
569
+ # Merge results
570
+ extracted = self._merge_ocr_results(extracted, tesseract_results)
571
+
572
+ # If still poor results, try with preprocessing
573
+ if len(extracted) < 3 or avg_confidence < 0.4:
574
+ for method in ['enhance', 'sharpen']:
575
+ try:
576
+ preprocessed = self._preprocess_image(original_image, method=method)
577
+ retry_results = self.reader.readtext(preprocessed)
578
+
579
+ retry_extracted = []
580
+ for bbox, text, conf in retry_results:
581
+ if conf >= min_confidence:
582
+ x_coords = [p[0] for p in bbox]
583
+ y_coords = [p[1] for p in bbox]
584
+ retry_extracted.append({
585
+ 'text': text.strip(),
586
+ 'confidence': conf,
587
+ 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
588
+ 'engine': 'easyocr'
589
+ })
590
+
591
+ # Use retry if it's better
592
+ retry_avg = np.mean([r['confidence'] for r in retry_extracted]) if retry_extracted else 0
593
+ if retry_avg > avg_confidence:
594
+ extracted = retry_extracted
595
+ break
596
+ except Exception as e:
597
+ continue
598
+
599
+ # Sort by confidence (highest first)
600
+ extracted.sort(key=lambda x: x['confidence'], reverse=True)
601
+
602
  return extracted
603
 
604
  def postprocess_receipt(self, ocr_results):
605
+ """Extract structured fields from OCR results with improved patterns."""
606
  full_text = ' '.join([r['text'] for r in ocr_results])
607
 
608
  fields = {
 
615
  return fields
616
 
617
  def _extract_vendor(self, ocr_results):
618
+ """Extract vendor name, usually in first few lines."""
619
+ if not ocr_results:
620
+ return None
621
+
622
+ # Look for vendor in top 3 results (usually at top of receipt)
623
+ top_results = sorted(ocr_results, key=lambda x: x['bbox'][1])[:3]
624
+
625
+ for result in top_results:
626
+ text = result['text'].strip()
627
+ # Skip common non-vendor words
628
+ if text and len(text) > 2 and text.upper() not in ['TOTAL', 'DATE', 'TIME', 'RECEIPT', 'THANK', 'YOU']:
629
+ # Take longest text as vendor (usually company name)
630
+ if len(text) > 5:
631
+ return text
632
+
633
+ return top_results[0]['text'] if top_results else None
634
 
635
  def _extract_date(self, text):
636
+ """Extract date with improved patterns."""
637
  patterns = [
638
+ r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', # MM/DD/YYYY or MM-DD-YYYY
639
+ r'\b\d{4}[/-]\d{2}[/-]\d{2}\b', # YYYY-MM-DD
640
+ r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{4}\b', # Month DD, YYYY
641
  ]
642
  for pattern in patterns:
643
+ matches = re.findall(pattern, text, re.IGNORECASE)
644
+ if matches:
645
+ return matches[0]
646
  return None
647
 
648
  def _extract_total(self, text):
649
+ """Extract total amount with improved patterns."""
650
+ # Look for TOTAL, AMOUNT, DUE keywords
651
  patterns = [
652
+ r'(?:TOTAL|AMOUNT|DUE|BALANCE)[:\s]*\$?\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)',
653
+ r'\$\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', # Any dollar amount
 
654
  ]
655
+
656
  for pattern in patterns:
657
+ matches = re.findall(pattern, text, re.IGNORECASE)
658
+ if matches:
659
+ # Return largest amount (usually the total)
660
+ amounts = [float(m.replace(',', '')) for m in matches]
661
+ return f"{max(amounts):.2f}"
662
 
 
 
 
 
663
  return None
664
 
665
  def _extract_time(self, text):
666
+ """Extract time."""
667
+ patterns = [
668
+ r'\b(\d{1,2}):(\d{2})\s*(?:AM|PM)\b',
669
+ r'\b(\d{1,2}):(\d{2})\b',
670
+ ]
671
+ for pattern in patterns:
672
+ match = re.search(pattern, text, re.IGNORECASE)
673
+ if match:
674
+ return match.group(0)
675
+ return None
676
 
677
  class LayoutLMFieldExtractor:
678
  """LayoutLMv3-based field extractor using fine-tuned weights if available."""
requirements.txt CHANGED
@@ -10,3 +10,4 @@ numpy>=1.21.0
10
  scikit-learn>=1.0.0
11
  opencv-python-headless>=4.5.0
12
 
 
 
10
  scikit-learn>=1.0.0
11
  opencv-python-headless>=4.5.0
12
 
13
+ pytesseract>=0.3.10