# aj718890's picture
# Update app.py
# 14ee8fa verified
# unified_movie_embeddings_demo_final.py
# ---------------------------------------------------------
# Unified Movie Embeddings Demo (Offline + Local Posters)
# ---------------------------------------------------------
# unified_movie_embeddings_demo_final.py
# ---------------------------------------------------------
# Unified Movie Embeddings Demo (HF-Safe Version)
# ---------------------------------------------------------
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
# ---------------------------------------------------------
# 1. Dataset
# ---------------------------------------------------------
# Demo catalogue: one [title, synopsis] pair per movie.
data = [
    ["Interstellar", "A team of explorers travel through a wormhole in space."],
    ["Inception", "A thief who steals secrets through dream-sharing technology."],
    ["The Martian", "An astronaut stranded on Mars must survive alone."],
    ["Gravity", "Two astronauts stranded in space after an accident."],
    ["Avatar", "A marine on an alien planet torn between duty and nature."],
]

# Tabular view of the catalogue; row order matches ``data``.
df = pd.DataFrame(data=data, columns=["title", "synopsis"])
# ---------------------------------------------------------
# 2. Local poster paths (use HF root directory)
# ---------------------------------------------------------
# Poster files are looked up relative to the current working directory.
POSTER_DIR = Path(".")

# Movie title -> path of its poster image inside POSTER_DIR.
POSTER_MAP = {
    title: POSTER_DIR / filename
    for title, filename in (
        ("Interstellar", "interstellar.jpg"),
        ("Inception", "inception.jpg"),
        ("The Martian", "martian.jpg"),
        ("Gravity", "gravity.jpg"),
        ("Avatar", "avatar.jpg"),
    )
}
def load_local_image(title: str) -> Image.Image:
    """Return the poster for *title* as an RGB image.

    The poster path is looked up in ``POSTER_MAP``. When the title has no
    entry, the file does not exist, or the file cannot be read as an image,
    a solid-colour 400x600 placeholder is returned instead.

    Args:
        title: Movie title used as the key into ``POSTER_MAP``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    path = POSTER_MAP.get(title)
    if path is not None and path.is_file():
        try:
            # The context manager closes the underlying file handle;
            # ``convert`` returns a new, fully loaded image that stays
            # usable after the source file is closed.
            with Image.open(path) as poster:
                return poster.convert("RGB")
        except OSError:
            # Unreadable or corrupt poster: fall through to the placeholder
            # instead of aborting start-up.
            pass
    # Fallback placeholder
    return Image.new("RGB", (400, 600), (100, 100, 180))
# Load one poster per movie, in the same order as the rows of ``df``.
posters = [load_local_image(title) for title in df["title"]]
# ---------------------------------------------------------
# 3. Models and embeddings
# ---------------------------------------------------------
# Sentence encoder for the text side, CLIP encoder for the poster side.
text_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
img_model = SentenceTransformer("sentence-transformers/clip-ViT-B-32")

# One string per movie: "<title>. <synopsis>".
texts = df["title"].str.cat(df["synopsis"], sep=". ").tolist()

# Each modality is encoded separately with unit-length output vectors.
text_emb = text_model.encode(texts, normalize_embeddings=True)
img_emb = img_model.encode(posters, normalize_embeddings=True)

# Fused embeddings: text and image vectors side by side, one row per movie,
# then re-normalised row-wise (epsilon guards against division by zero).
fused = np.concatenate((text_emb, img_emb), axis=1).astype(np.float32)
fused /= np.linalg.norm(fused, axis=1, keepdims=True) + 1e-12

# ---------------------------------------------------------
# 4. FAISS index
# ---------------------------------------------------------
# Flat inner-product index whose dimensionality is the fused vector width.
index = faiss.IndexFlatIP(fused.shape[-1])
index.add(fused)
def search_movies(query: str, k: int = 3) -> list:
    """Return the best-matching movies for a free-text query.

    The query is embedded with ``text_model`` and padded with zeros in place
    of the image embedding so that it has the same width as the fused vectors
    stored in ``index``. The padded vector is normalised and searched against
    the inner-product index.

    Args:
        query: Free-text description of the desired movie.
        k: Maximum number of results. Values larger than the number of
            indexed movies are capped to that number.

    Returns:
        A list of ``(poster, caption)`` tuples in the order returned by the
        index, where the caption is ``"<title> — <synopsis>"``. The list is
        empty when the query is blank or ``k`` is not positive.
    """
    if not query or not query.strip():
        return []

    # Asking FAISS for more neighbours than it holds yields -1 labels, which
    # would otherwise index the *last* row/poster via negative indexing.
    k = min(int(k), index.ntotal)
    if k <= 0:
        return []

    q_text = text_model.encode([query], normalize_embeddings=True)
    # No query image is available, so the image half is all zeros.
    q_img = np.zeros((1, img_emb.shape[1]), dtype="float32")
    q_fused = np.hstack([q_text, q_img]).astype("float32")
    q_fused /= np.linalg.norm(q_fused, axis=1, keepdims=True) + 1e-12

    _, indices = index.search(q_fused, k)

    results = []
    for idx in indices[0]:
        if idx < 0:
            # Defensive: skip "no result" labels.
            continue
        row = df.iloc[idx]
        caption = f"{row['title']} — {row['synopsis']}"
        results.append((posters[idx], caption))
    return results
# ---------------------------------------------------------
# 5. Gradio UI (HF safe)
# ---------------------------------------------------------
# Input: a single free-text box pre-filled with an example query.
query_box = gr.Textbox(
    label="Describe the kind of movie you want:",
    value="sci-fi thriller",
)
# Output: a gallery fed by the (image, caption) pairs from search_movies.
results_gallery = gr.Gallery(label="Recommended Titles")

demo = gr.Interface(
    fn=search_movies,
    inputs=query_box,
    outputs=results_gallery,
    title="Unified Movie Embeddings Demo",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()