# pixeltext-ai / modeling_paligemma_ocr.py
# Last commit (BabaK07): "Upload custom PaliGemma OCR model" (b8a8a54, verified)
#!/usr/bin/env python3
"""
Fixed Custom OCR Model based on PaliGemma-3B
Handles device placement issues and provides better OCR performance
"""
import torch
import torch.nn as nn
from transformers import (
PaliGemmaForConditionalGeneration,
PaliGemmaProcessor,
AutoTokenizer
)
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
class FixedPaliGemmaOCR(nn.Module):
    """
    OCR wrapper around a pretrained PaliGemma checkpoint.

    The wrapper loads the model, processor and tokenizer once, moves the
    model to the selected device, and exposes helpers that turn an image
    into a result dictionary with the keys ``text``, ``confidence``,
    ``quality`` and ``method`` (plus ``raw_output`` or ``error``).

    Attributes:
        device: "cuda" when CUDA is available, otherwise "cpu".
        torch_dtype: torch.float16 on CUDA, torch.float32 on CPU.
        base_model: The loaded PaliGemmaForConditionalGeneration instance.
        processor: The PaliGemmaProcessor used to build model inputs.
        tokenizer: The tokenizer used for the EOS / padding token ids.
        hidden_size: Hidden size taken from the model's text config.
        vocab_size: Vocabulary size taken from the model's text config.
    """

    # Lead-in phrases that are stripped from the start of the generated text
    # (matched case-insensitively) by _clean_generated_text.
    _ARTIFACT_PREFIXES = (
        "The image shows",
        "The text in the image says",
        "The image contains the text",
        "I can see the text",
        "The text reads",
    )

    def __init__(self, model_name="google/paligemma-3b-pt-224"):
        """
        Load the model, processor and tokenizer for ``model_name``.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.

        Raises:
            Exception: Any error raised while loading or moving the model is
                logged and re-raised unchanged.
        """
        super().__init__()
        print("🚀 Initializing Fixed PaliGemma OCR Model...")
        print(f"📦 Base model: {model_name}")

        # Determine best device and dtype
        if torch.cuda.is_available():
            self.device = "cuda"
            self.torch_dtype = torch.float16
            print("🔧 Using CUDA with float16")
        else:
            self.device = "cpu"
            self.torch_dtype = torch.float32
            print("🔧 Using CPU with float32")

        # Load model components
        try:
            print("📥 Loading PaliGemma model...")
            self.base_model = PaliGemmaForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
            )
            print("📥 Loading processor...")
            self.processor = PaliGemmaProcessor.from_pretrained(model_name)
            print("📥 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Move model to device
            self.base_model = self.base_model.to(self.device)
            print("✅ All components loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load PaliGemma model: {e}")
            raise

        # Get model dimensions
        self.hidden_size = self.base_model.config.text_config.hidden_size
        self.vocab_size = self.base_model.config.text_config.vocab_size

        # Simple confidence estimation (no custom heads to avoid device issues)
        print("🔧 Model ready:")
        print(f" - Device: {self.device}")
        print(f" - Hidden size: {self.hidden_size}")
        print(f" - Vocab size: {self.vocab_size}")
        print(" - Parameters: ~3B")

    def generate_ocr_text(self, image, prompt="<image>Extract all text from this image:", max_length=512):
        """
        Generate OCR text from image with proper device handling.

        The standard extraction is tried first; if it raises, a set of
        fallback prompts is tried; if that raises too, an error result is
        returned instead of propagating the exception.

        Args:
            image: PIL Image or path to image
            prompt: Text prompt for OCR task (must include <image> token)
            max_length: Maximum number of tokens to generate

        Returns:
            dict: Contains extracted text, confidence, and metadata

        Raises:
            ValueError: If ``image`` is neither a PIL image nor a path string.
        """
        if isinstance(image, str):
            # Close the file handle as soon as the pixel data has been copied.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif isinstance(image, Image.Image):
            # Treat in-memory images the same way as images loaded from disk.
            if image.mode != 'RGB':
                image = image.convert('RGB')
        else:
            raise ValueError("Image must be PIL Image or path string")

        try:
            # Method 1: Standard PaliGemma OCR
            result = self._extract_with_paligemma(image, prompt, max_length)
            result['method'] = 'paligemma_standard'
            return result
        except Exception as e:
            print(f"⚠️ Standard method failed: {e}")
            try:
                # Method 2: Fallback with different prompts
                result = self._extract_with_fallback(image, max_length)
                result['method'] = 'paligemma_fallback'
                return result
            except Exception as e2:
                print(f"⚠️ Fallback method failed: {e2}")
                # Method 3: Error handling
                return {
                    'text': "Error: Could not extract text from image",
                    'confidence': 0.0,
                    'quality': 'error',
                    'method': 'error',
                    'error': str(e2)
                }

    def _run_generation(self, image, prompt, max_length, **generate_kwargs):
        """
        Run one processor -> generate -> decode round trip.

        Args:
            image: PIL image handed to the processor.
            prompt: Prompt text handed to the processor.
            max_length: Upper bound on the number of newly generated tokens.
            **generate_kwargs: Extra keyword arguments for ``generate``.

        Returns:
            tuple: ``(continuation, raw_output)`` where ``continuation`` is
            the decoded text of the tokens after the input positions and
            ``raw_output`` is the decoded text of the whole returned sequence.
        """
        inputs = self.processor(
            text=prompt,
            images=image,
            return_tensors="pt"
        )
        # Move all tensor inputs to device
        for key in inputs:
            if isinstance(inputs[key], torch.Tensor):
                inputs[key] = inputs[key].to(self.device)

        # Number of input positions; everything after them is new output.
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            generated_ids = self.base_model.generate(
                **inputs,
                # max_new_tokens bounds only the generated part; max_length
                # would also count the input tokens against the budget.
                max_new_tokens=max_length,
                pad_token_id=self.tokenizer.eos_token_id,
                **generate_kwargs
            )

        raw_output = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]
        continuation = self.processor.batch_decode(
            generated_ids[:, prompt_len:],
            skip_special_tokens=True
        )[0]
        return continuation, raw_output

    def _extract_with_paligemma(self, image, prompt, max_length):
        """
        Extract text using PaliGemma's standard approach (greedy decoding).

        Args:
            image: PIL image to read.
            prompt: Prompt text; "<image>" is prepended when missing.
            max_length: Maximum number of tokens to generate.

        Returns:
            dict: ``text``, ``confidence``, ``quality`` and ``raw_output``.

        Raises:
            Exception: Any failure is logged and re-raised for the caller.
        """
        try:
            # Prepare inputs with proper prompt format
            if "<image>" not in prompt:
                prompt = f"<image>{prompt}"

            continuation, raw_output = self._run_generation(
                image,
                prompt,
                max_length,
                do_sample=False,
                num_beams=1,
                eos_token_id=self.tokenizer.eos_token_id
            )

            # Clean up the text
            extracted_text = self._clean_generated_text(continuation, prompt)

            return {
                'text': extracted_text,
                # Estimate confidence based on output quality
                'confidence': self._estimate_confidence(extracted_text),
                'quality': self._assess_quality(extracted_text),
                'raw_output': raw_output
            }
        except Exception as e:
            print(f"❌ PaliGemma extraction failed: {e}")
            raise

    def _extract_with_fallback(self, image, max_length):
        """
        Fallback extraction with different prompts.

        Each prompt is tried in order with low-temperature sampling; the
        first non-empty result wins. A prompt that raises is logged and
        skipped.

        Args:
            image: PIL image to read.
            max_length: Maximum number of tokens to generate.

        Returns:
            dict: ``text``, ``confidence``, ``quality`` and ``raw_output``.
            Successful results report a fixed confidence of 0.7 and quality
            'good'; if every prompt fails the text is empty, the confidence
            is 0.0 and the quality is 'poor'.
        """
        fallback_prompts = [
            "<image>What text is visible in this image?",
            "<image>Read all the text in this image.",
            "<image>OCR this image.",
            "<image>Transcribe the text.",
            "<image>"
        ]

        for prompt in fallback_prompts:
            try:
                continuation, raw_output = self._run_generation(
                    image,
                    prompt,
                    max_length,
                    do_sample=True,
                    temperature=0.1,
                    top_p=0.9,
                    num_beams=1
                )
                extracted_text = self._clean_generated_text(continuation, prompt)
                if extracted_text.strip():
                    return {
                        'text': extracted_text,
                        'confidence': 0.7,
                        'quality': 'good',
                        'raw_output': raw_output
                    }
            except Exception as e:
                print(f"⚠️ Fallback prompt '{prompt}' failed: {e}")
                continue

        # All fallbacks failed
        return {
            'text': "",
            'confidence': 0.0,
            'quality': 'poor',
            'raw_output': ""
        }

    def _clean_generated_text(self, generated_text, prompt):
        """
        Clean up generated text by removing prompt and artifacts.

        Args:
            generated_text: Decoded model output.
            prompt: Prompt that was sent to the model (may contain "<image>").

        Returns:
            str: The text with a leading echo of the prompt and known
            lead-in phrases removed, stripped of surrounding whitespace.
        """
        text = generated_text.strip()

        # Remove the prompt only where it is echoed at the very start;
        # the same words appearing later are part of the recognised text.
        prompt_text = prompt.replace("<image>", "").strip()
        if prompt_text and text.startswith(prompt_text):
            text = text[len(prompt_text):].strip()

        # Remove common artifacts
        for prefix in self._ARTIFACT_PREFIXES:
            if text.lower().startswith(prefix.lower()):
                text = text[len(prefix):].strip()
                if text.startswith(":"):
                    text = text[1:].strip()
                # Unwrap the quotation that usually follows such a lead-in.
                if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
                    text = text[1:-1].strip()

        return text

    def _estimate_confidence(self, text):
        """
        Estimate confidence based on text characteristics.

        All measurements use the whitespace-stripped text so that padding
        cannot earn a length bonus.

        Args:
            text: Extracted text (may be empty or None).

        Returns:
            float: 0.0 for empty text, otherwise a heuristic score that is
            capped at 0.95.
        """
        stripped = text.strip() if text else ""
        if not stripped:
            return 0.0

        # Base confidence
        confidence = 0.5

        # Length bonus
        if len(stripped) > 10:
            confidence += 0.2
        if len(stripped) > 50:
            confidence += 0.1

        # Character variety bonus
        if any(c.isalpha() for c in stripped):
            confidence += 0.1
        if any(c.isdigit() for c in stripped):
            confidence += 0.05

        # Penalty for very short or suspicious text
        if len(stripped) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def _assess_quality(self, text):
        """
        Assess text quality from the length of the stripped text.

        Args:
            text: Extracted text (may be empty or None).

        Returns:
            str: 'poor' (< 5 characters), 'fair' (< 20), 'good' (< 100)
            or 'excellent'.
        """
        length = len(text.strip()) if text else 0
        if length < 5:
            return 'poor'
        if length < 20:
            return 'fair'
        if length < 100:
            return 'good'
        return 'excellent'

    def batch_ocr(self, images, prompt="<image>Extract all text from this image:", max_length=512):
        """
        Process multiple images one after another.

        Args:
            images: Iterable of PIL images or image paths.
            prompt: Text prompt used for every image.
            max_length: Maximum number of tokens to generate per image.

        Returns:
            list: One result dict per input image, in input order. An image
            whose processing raises yields an error result instead.
        """
        # Materialise once so generators work and the total is known.
        images = list(images)
        results = []
        for i, image in enumerate(images):
            print(f"📄 Processing image {i+1}/{len(images)}...")
            try:
                result = self.generate_ocr_text(image, prompt, max_length)
                results.append(result)
                print(f" ✅ Success: {len(result['text'])} characters extracted")
            except Exception as e:
                print(f" ❌ Error: {e}")
                results.append({
                    'text': f"Error processing image {i+1}",
                    'confidence': 0.0,
                    'quality': 'error',
                    'method': 'error',
                    'error': str(e)
                })
        return results

    def get_model_info(self):
        """
        Get comprehensive model information.

        Returns:
            dict: Static descriptive fields plus the device, dtype, hidden
            size and vocabulary size recorded at construction time.
        """
        return {
            'base_model': 'PaliGemma-3B',
            'device': self.device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'optimized_for': 'OCR and Document Understanding',
            'supported_languages': '100+',
            'features': [
                'Multi-language OCR',
                'Document understanding',
                'Robust error handling',
                'Batch processing',
                'Confidence estimation'
            ]
        }
def main():
    """
    Test the Fixed PaliGemma OCR Model.

    Builds the model, prints its info, draws a synthetic invoice image,
    saves it as ``test_paligemma_ocr.png`` in the current directory and
    runs OCR on it.

    Returns:
        The constructed FixedPaliGemmaOCR instance, or None if any step
        raised (the traceback is printed in that case).
    """
    print("🚀 Testing Fixed PaliGemma OCR Model")
    print("=" * 50)
    try:
        # Initialize model
        model = FixedPaliGemmaOCR()

        # Print model info
        info = model.get_model_info()
        print("\n📊 Model Information:")
        for key, value in info.items():
            if isinstance(value, list):
                print(f" {key}:")
                for item in value:
                    print(f" - {item}")
            else:
                print(f" {key}: {value}")

        # Create test image
        print("\n🧪 Creating test image...")
        from PIL import Image, ImageDraw, ImageFont
        img = Image.new('RGB', (500, 300), color='white')
        draw = ImageDraw.Draw(img)

        # Try a few common font locations; fall back to the built-in
        # default font when none of them can be opened.
        font_candidates = (
            "/System/Library/Fonts/Arial.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "arial.ttf",
        )
        for font_path in font_candidates:
            try:
                font = ImageFont.truetype(font_path, 20)
                title_font = ImageFont.truetype(font_path, 28)
            except OSError:
                # Font file missing or unreadable: try the next one.
                continue
            break
        else:
            font = ImageFont.load_default()
            title_font = font

        # Add various text elements
        draw.text((20, 30), "INVOICE #12345", fill='black', font=title_font)
        draw.text((20, 80), "Date: January 15, 2024", fill='black', font=font)
        draw.text((20, 110), "Customer: John Smith", fill='blue', font=font)
        draw.text((20, 140), "Amount: $1,234.56", fill='red', font=font)
        draw.text((20, 170), "Description: Professional Services", fill='black', font=font)
        draw.text((20, 200), "Tax (10%): $123.46", fill='black', font=font)
        draw.text((20, 230), "Total: $1,358.02", fill='black', font=title_font)
        img.save("test_paligemma_ocr.png")
        print("✅ Test image created: test_paligemma_ocr.png")

        # Test OCR
        print("\n🔍 Testing OCR extraction...")
        result = model.generate_ocr_text(img)
        print("\n📝 OCR Results:")
        print(f" Text: {result['text']}")
        print(f" Confidence: {result['confidence']:.3f}")
        print(f" Quality: {result['quality']}")
        print(f" Method: {result['method']}")
        if len(result['text']) > 0:
            print("\n✅ PaliGemma OCR Model is working perfectly!")
        else:
            print("\n⚠️ OCR extracted no text - may need adjustment")
        return model
    except Exception as e:
        # Top-level boundary of the smoke test: report and signal failure.
        print(f"❌ Error testing model: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Run the smoke test when executed as a script; the result (the model
    # wrapper, or None on failure) is bound to the module-level name `model`.
    model = main()