|
|
|
|
|
""" |
|
|
Fixed Custom OCR Model based on PaliGemma-3B |
|
|
Handles device placement issues and provides better OCR performance |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import ( |
|
|
PaliGemmaForConditionalGeneration, |
|
|
PaliGemmaProcessor, |
|
|
AutoTokenizer |
|
|
) |
|
|
from PIL import Image |
|
|
import warnings |
|
|
# Suppress every Python warning for the whole process as soon as this module is
# imported (presumably to hide noisy transformers/torch messages — confirm).
# NOTE(review): this is a global side effect that also hides warnings raised by
# unrelated code in the importing program; consider scoping it with
# warnings.catch_warnings() around the noisy calls instead.
warnings.filterwarnings("ignore")
|
|
|
|
|
class FixedPaliGemmaOCR(nn.Module):
    """
    OCR wrapper around a PaliGemma checkpoint with explicit device and dtype handling.

    On construction the Hugging Face ``PaliGemmaForConditionalGeneration`` model and its
    ``PaliGemmaProcessor`` are loaded. The model is placed on CUDA in float16 when a GPU
    is available, and on CPU in float32 otherwise. Before generation, every input tensor
    is moved to that device, and floating-point tensors are cast to that dtype.

    Results are plain dicts with the extracted ``text``, a heuristic ``confidence``
    (0.0-0.95), a coarse ``quality`` label and the ``method`` that produced them.
    """

    def __init__(self, model_name="google/paligemma-3b-pt-224"):
        """
        Load the PaliGemma model and processor and move the model to the chosen device.

        Args:
            model_name: Hugging Face Hub id or local path of a PaliGemma checkpoint.

        Raises:
            Exception: whatever loading raises; it is printed, then re-raised unchanged.
        """
        super().__init__()

        self.model_name = model_name
        print("🚀 Initializing Fixed PaliGemma OCR Model...")
        print(f"📦 Base model: {model_name}")

        # Half precision on GPU, full precision on CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
            self.torch_dtype = torch.float16
            print("🔧 Using CUDA with float16")
        else:
            self.device = "cpu"
            self.torch_dtype = torch.float32
            print("🔧 Using CPU with float32")

        try:
            print("📥 Loading PaliGemma model...")
            self.base_model = PaliGemmaForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
            )

            print("📥 Loading processor...")
            self.processor = PaliGemmaProcessor.from_pretrained(model_name)

            # The processor already wraps this checkpoint's tokenizer; reuse it instead
            # of loading a second copy of the same tokenizer.
            self.tokenizer = self.processor.tokenizer

            self.base_model = self.base_model.to(self.device)

            print("✅ All components loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load PaliGemma model: {e}")
            raise

        text_config = self.base_model.config.text_config
        self.hidden_size = text_config.hidden_size
        self.vocab_size = text_config.vocab_size

        print("🔧 Model ready:")
        print(f" - Device: {self.device}")
        print(f" - Hidden size: {self.hidden_size}")
        print(f" - Vocab size: {self.vocab_size}")
        print(" - Parameters: ~3B")

    def generate_ocr_text(self, image, prompt="<image>Extract all text from this image:", max_length=512):
        """
        Run OCR on a single image.

        The caller's prompt is tried first with greedy decoding. If that raises, a list of
        alternative prompts is tried (see ``_extract_with_fallback``).

        Args:
            image: PIL Image, or a path (``str`` / ``os.PathLike``) to an image file.
            prompt: Text prompt for the OCR task; ``<image>`` is prepended if missing.
            max_length: Maximum number of *new* tokens to generate. The prompt and image
                tokens are not counted.

        Returns:
            dict with ``text``, ``confidence``, ``quality`` and ``method``, plus
            ``raw_output`` on success or ``error`` when both strategies raised.

        Raises:
            ValueError: If ``image`` is neither a PIL Image nor a path.
            OSError: If ``image`` is a path that cannot be opened as an image.
        """
        import os

        if isinstance(image, (str, os.PathLike)):
            # The context manager closes the file; convert() returns an in-memory copy.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif not isinstance(image, Image.Image):
            raise ValueError("Image must be PIL Image or path string")

        try:
            result = self._extract_with_paligemma(image, prompt, max_length)
            result['method'] = 'paligemma_standard'
            return result
        except Exception as e:
            print(f"⚠️ Standard method failed: {e}")

        try:
            result = self._extract_with_fallback(image, max_length)
            result['method'] = 'paligemma_fallback'
            return result
        except Exception as e2:
            print(f"⚠️ Fallback method failed: {e2}")
            return {
                'text': "Error: Could not extract text from image",
                'confidence': 0.0,
                'quality': 'error',
                'method': 'error',
                'error': str(e2),
            }

    def _prepare_inputs(self, image, prompt):
        """
        Run the processor on ``prompt`` + ``image`` and move every tensor to ``self.device``.

        Floating-point tensors (the pixel values) are also cast to ``self.torch_dtype`` so
        they match the model weights. Integer tensors (ids, masks) keep their dtype.
        """
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        prepared = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                value = value.to(self.device)
                if torch.is_floating_point(value):
                    value = value.to(self.torch_dtype)
            prepared[key] = value
        return prepared

    def _decode_new_tokens(self, generated_ids, inputs):
        """
        Decode only the tokens produced by ``generate`` for the first sequence.

        ``generate`` returns the prompt (image + text tokens) followed by the new tokens,
        so the first ``input_ids.shape[-1]`` positions are dropped before decoding.
        """
        input_len = inputs["input_ids"].shape[-1]
        return self.processor.batch_decode(
            generated_ids[:, input_len:],
            skip_special_tokens=True,
        )[0]

    def _extract_with_paligemma(self, image, prompt, max_length):
        """
        Extract text with greedy decoding using the caller's prompt.

        Args:
            image: PIL image passed straight to the processor.
            prompt: OCR prompt; ``<image>`` is prepended when absent.
            max_length: Maximum number of new tokens to generate.

        Returns:
            dict with ``text``, ``confidence``, ``quality`` and ``raw_output`` (the decoded
            new tokens before cleaning).

        Raises:
            Exception: any processor/generation error, after printing it.
        """
        try:
            if "<image>" not in prompt:
                prompt = f"<image>{prompt}"

            inputs = self._prepare_inputs(image, prompt)

            with torch.no_grad():
                generated_ids = self.base_model.generate(
                    **inputs,
                    # max_new_tokens, not max_length: max_length also counts the prompt
                    # and the many image tokens, leaving little or no room for output.
                    max_new_tokens=max_length,
                    do_sample=False,
                    num_beams=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            generated_text = self._decode_new_tokens(generated_ids, inputs)
            extracted_text = self._clean_generated_text(generated_text, prompt)

            return {
                'text': extracted_text,
                'confidence': self._estimate_confidence(extracted_text),
                'quality': self._assess_quality(extracted_text),
                'raw_output': generated_text,
            }

        except Exception as e:
            print(f"❌ PaliGemma extraction failed: {e}")
            raise

    def _extract_with_fallback(self, image, max_length):
        """
        Retry OCR with alternative prompts and return the first non-empty answer.

        Unlike the standard path, this uses low-temperature sampling (``temperature=0.1``,
        ``top_p=0.9``), so repeated calls may return slightly different text.

        Args:
            image: PIL image passed straight to the processor.
            max_length: Maximum number of new tokens to generate per prompt.

        Returns:
            dict with ``text``, ``confidence``, ``quality`` and ``raw_output``. A prompt
            that raises is printed and skipped. If no prompt yields text, an empty
            ``'poor'`` result is returned instead of raising.
        """
        fallback_prompts = [
            "<image>What text is visible in this image?",
            "<image>Read all the text in this image.",
            "<image>OCR this image.",
            "<image>Transcribe the text.",
            "<image>",
        ]

        for prompt in fallback_prompts:
            try:
                inputs = self._prepare_inputs(image, prompt)

                with torch.no_grad():
                    generated_ids = self.base_model.generate(
                        **inputs,
                        max_new_tokens=max_length,
                        do_sample=True,
                        temperature=0.1,
                        top_p=0.9,
                        num_beams=1,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                generated_text = self._decode_new_tokens(generated_ids, inputs)
                extracted_text = self._clean_generated_text(generated_text, prompt)
            except Exception as e:
                print(f"⚠️ Fallback prompt '{prompt}' failed: {e}")
                continue

            if extracted_text.strip():
                return {
                    'text': extracted_text,
                    'confidence': 0.7,
                    'quality': 'good',
                    'raw_output': generated_text,
                }

        return {
            'text': "",
            'confidence': 0.0,
            'quality': 'poor',
            'raw_output': "",
        }

    def _clean_generated_text(self, generated_text, prompt):
        """
        Strip an echoed prompt and conversational lead-ins from model output.

        Args:
            generated_text: Decoded model output.
            prompt: The prompt used for generation (``<image>`` tokens are ignored).

        Returns:
            The cleaned, whitespace-stripped text.
        """
        extracted_text = generated_text.strip()

        # Drop an echoed prompt only when it leads the output. A global replace would also
        # delete the same words where they legitimately occur in the transcription.
        clean_prompt = prompt.replace("<image>", "").strip()
        if clean_prompt and extracted_text.startswith(clean_prompt):
            extracted_text = extracted_text[len(clean_prompt):].strip()

        # Conversational lead-ins the model sometimes emits before the actual text.
        artifacts = (
            "The image shows",
            "The text in the image says",
            "The image contains the text",
            "I can see the text",
            "The text reads",
        )

        for artifact in artifacts:
            if extracted_text.lower().startswith(artifact.lower()):
                extracted_text = extracted_text[len(artifact):].strip()
                if extracted_text.startswith(":"):
                    extracted_text = extracted_text[1:].strip()
                if extracted_text.startswith('"') and extracted_text.endswith('"'):
                    extracted_text = extracted_text[1:-1].strip()

        return extracted_text

    def _estimate_confidence(self, text):
        """
        Heuristic confidence in [0.0, 0.95] from length and character content.

        This is not a model probability. Longer text and text containing letters or
        digits score higher, and text under 3 characters is halved.
        """
        stripped = text.strip() if text else ""
        if not stripped:
            return 0.0

        confidence = 0.5

        # Longer transcriptions are treated as more trustworthy.
        if len(stripped) > 10:
            confidence += 0.2
        if len(stripped) > 50:
            confidence += 0.1

        # Reward real alphanumeric content over punctuation-only output.
        if any(c.isalpha() for c in stripped):
            confidence += 0.1
        if any(c.isdigit() for c in stripped):
            confidence += 0.05

        # Penalise near-empty output.
        if len(stripped) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def _assess_quality(self, text):
        """Map stripped text length to 'poor' (<5), 'fair' (<20), 'good' (<100) or 'excellent'."""
        length = len(text.strip()) if text else 0

        if length < 5:
            return 'poor'
        if length < 20:
            return 'fair'
        if length < 100:
            return 'good'
        return 'excellent'

    def batch_ocr(self, images, prompt="<image>Extract all text from this image:", max_length=512):
        """
        Run ``generate_ocr_text`` on each image in turn.

        Args:
            images: Any iterable of PIL Images or paths; generators are accepted.
            prompt: OCR prompt forwarded to ``generate_ocr_text``.
            max_length: Maximum number of new tokens per image.

        Returns:
            One result dict per image, in input order. An image that raises gets an
            ``'error'`` dict instead of aborting the batch.
        """
        # Materialise once so len() works for generators as well as sequences.
        images = list(images)
        total = len(images)
        results = []

        for index, image in enumerate(images, start=1):
            print(f"🔄 Processing image {index}/{total}...")

            try:
                result = self.generate_ocr_text(image, prompt, max_length)
            except Exception as e:
                print(f" ❌ Error: {e}")
                results.append({
                    'text': f"Error processing image {index}",
                    'confidence': 0.0,
                    'quality': 'error',
                    'method': 'error',
                    'error': str(e),
                })
                continue

            results.append(result)
            # generate_ocr_text reports total failure as a dict rather than raising.
            if result.get('method') == 'error':
                print(f" ❌ Error: {result.get('error')}")
            else:
                print(f" ✅ Success: {len(result['text'])} characters extracted")

        return results

    def get_model_info(self):
        """Return a dict describing the loaded model, its device/dtype and advertised features."""
        return {
            'base_model': 'PaliGemma-3B',
            'model_name': self.model_name,
            'device': self.device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'optimized_for': 'OCR and Document Understanding',
            'supported_languages': '100+',
            'features': [
                'Multi-language OCR',
                'Document understanding',
                'Robust error handling',
                'Batch processing',
                'Confidence estimation',
            ],
        }
|
|
|
|
|
|
|
|
def main():
    """
    Smoke-test FixedPaliGemmaOCR end to end.

    Loads the model and prints its metadata. It then renders a synthetic invoice image,
    saves it as ``test_paligemma_ocr.png`` in the current directory, and runs OCR on it.

    Returns:
        The loaded ``FixedPaliGemmaOCR`` instance on success, or ``None`` if anything
        raised (the traceback is printed).
    """
    print("🚀 Testing Fixed PaliGemma OCR Model")
    print("=" * 50)

    try:
        model = FixedPaliGemmaOCR()

        info = model.get_model_info()
        print("\n📊 Model Information:")
        for key, value in info.items():
            if isinstance(value, list):
                print(f" {key}:")
                for item in value:
                    print(f" - {item}")
            else:
                print(f" {key}: {value}")

        print("\n🧪 Creating test image...")
        from PIL import Image, ImageDraw, ImageFont

        img = Image.new('RGB', (500, 300), color='white')
        draw = ImageDraw.Draw(img)

        # Try a few common TrueType fonts (macOS, Windows, Linux) before falling back to
        # Pillow's built-in bitmap font. Pillow raises OSError for a missing font file
        # and ImportError when FreeType support is unavailable.
        font = title_font = None
        for font_path in (
            "/System/Library/Fonts/Arial.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "arial.ttf",
            "DejaVuSans.ttf",
        ):
            try:
                font = ImageFont.truetype(font_path, 20)
                title_font = ImageFont.truetype(font_path, 28)
                break
            except (OSError, ImportError):
                continue
        if font is None:
            font = ImageFont.load_default()
            title_font = font

        draw.text((20, 30), "INVOICE #12345", fill='black', font=title_font)
        draw.text((20, 80), "Date: January 15, 2024", fill='black', font=font)
        draw.text((20, 110), "Customer: John Smith", fill='blue', font=font)
        draw.text((20, 140), "Amount: $1,234.56", fill='red', font=font)
        draw.text((20, 170), "Description: Professional Services", fill='black', font=font)
        draw.text((20, 200), "Tax (10%): $123.46", fill='black', font=font)
        draw.text((20, 230), "Total: $1,358.02", fill='black', font=title_font)

        img.save("test_paligemma_ocr.png")
        print("✅ Test image created: test_paligemma_ocr.png")

        print("\n🔍 Testing OCR extraction...")
        result = model.generate_ocr_text(img)

        print("\n📝 OCR Results:")
        print(f" Text: {result['text']}")
        print(f" Confidence: {result['confidence']:.3f}")
        print(f" Quality: {result['quality']}")
        print(f" Method: {result['method']}")

        # A total failure is returned as a dict whose 'text' is a non-empty error
        # message, so check the method before treating non-empty text as success.
        if result.get('method') == 'error':
            print(f"\n❌ OCR failed: {result.get('error')}")
        elif result['text']:
            print("\n✅ PaliGemma OCR Model is working perfectly!")
        else:
            print("\n⚠️ OCR extracted no text - may need adjustment")

        return model

    except Exception as e:
        print(f"❌ Error testing model: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    # main() returns the loaded model on success or None on failure; binding it at
    # module level keeps it available for inspection (e.g. under `python -i`).
    model = main()