import base64
import io
from typing import Any, Dict

import requests
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel, CLIPProcessor

# Upper bound (seconds) on connecting to / reading from a remote image host,
# so an unresponsive server cannot block the handler indefinitely.
_IMAGE_FETCH_TIMEOUT_S = 10


class EndpointHandler:
    """Custom endpoint handler that embeds a text/image pair with CLIP.

    The processor and model are loaded once at construction time and reused
    for every call.
    """

    def __init__(self, path=""):
        """Load the CLIP processor and model from *path*.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.processor = CLIPProcessor.from_pretrained(path)
        self.model = CLIPModel.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict:
        """Embed the request's text and image and compare the embeddings.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself:

                * ``"text"`` (required): text handed to the processor.
                * ``"image_url"``: URL the image is downloaded from; takes
                  precedence over ``"image"`` when both are present.
                * ``"image"``: base64-encoded image bytes.

        Returns:
            A dict with ``"text_embedding"`` and ``"image_embedding"`` (the
            first row of each embedding matrix, as lists) and
            ``"embedding_similarity"`` (cosine similarity between the first
            text embedding and the first image embedding, as a float).

        Raises:
            KeyError: If ``"text"`` is missing, or if neither ``"image_url"``
                nor ``"image"`` is supplied.
            requests.RequestException: If the image download fails, times
                out, or returns an HTTP error status.
        """
        print("this shows the custom endpoint handler is being called")

        # Read without pop() so the caller's request dict is left untouched.
        inputs = data.get("inputs", data)
        text = inputs["text"]
        image = self._load_image(inputs)

        processed_inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = self.model(**processed_inputs)

        text_embeds = outputs.text_embeds
        image_embeds = outputs.image_embeds

        embedding_similarity = cosine_similarity(
            text_embeds.cpu().numpy(),
            image_embeds.cpu().numpy(),
        )[0][0].item()

        return {
            "text_embedding": text_embeds[0].tolist(),
            "image_embedding": image_embeds[0].tolist(),
            "embedding_similarity": embedding_similarity,
        }

    @staticmethod
    def _load_image(inputs: Dict[str, Any]) -> Image.Image:
        """Return the request's image, fetched from a URL or decoded from base64."""
        if "image_url" in inputs:
            # NOTE(review): the URL comes straight from the request payload;
            # if this endpoint is reachable by untrusted clients, consider
            # restricting the allowed hosts/schemes (SSRF).
            response = requests.get(
                inputs["image_url"], timeout=_IMAGE_FETCH_TIMEOUT_S
            )
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # ``content`` is fully read and content-decoded (gzip/deflate),
            # unlike the raw socket stream.
            return Image.open(io.BytesIO(response.content))

        if "image" not in inputs:
            raise KeyError(
                "request must contain either 'image_url' or a base64-encoded 'image'"
            )
        return Image.open(io.BytesIO(base64.b64decode(inputs["image"])))