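# Assumed dependencies (not pinned in the source):
#   pip install numpy pandas pillow sentence-transformers faiss-cpu gradio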
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from sentence_transformers import SentenceTransformer

# Small demo catalogue: movie titles paired with one-line synopses.
data = [
    ["Interstellar", "A team of explorers travel through a wormhole in space."],
    ["Inception", "A thief who steals secrets through dream-sharing technology."],
    ["The Martian", "An astronaut stranded on Mars must survive alone."],
    ["Gravity", "Two astronauts stranded in space after an accident."],
    ["Avatar", "A marine on an alien planet torn between duty and nature."],
]
df = pd.DataFrame(data, columns=["title", "synopsis"])

# Poster images are looked up next to this script; missing files fall back
# to a placeholder so the demo still runs without any downloads.
POSTER_DIR = Path(".")

POSTER_MAP = {
    "Interstellar": POSTER_DIR / "interstellar.jpg",
    "Inception": POSTER_DIR / "inception.jpg",
    "The Martian": POSTER_DIR / "martian.jpg",
    "Gravity": POSTER_DIR / "gravity.jpg",
    "Avatar": POSTER_DIR / "avatar.jpg",
}


def load_local_image(title: str) -> Image.Image:
    """Load a poster from disk, or return a solid-colour placeholder if it is missing."""
    path = POSTER_MAP.get(title)
    if path and path.exists():
        return Image.open(path).convert("RGB")
    # Fallback: plain 400x600 placeholder image.
    img = Image.new("RGB", (400, 600), (100, 100, 180))
    return img


posters = [load_local_image(row.title) for row in df.itertuples()]

# Two encoders: MiniLM for the text side (384-d) and CLIP ViT-B/32 for the posters (512-d).
text_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
img_model = SentenceTransformer("sentence-transformers/clip-ViT-B-32")

texts = (df["title"] + ". " + df["synopsis"]).tolist()
text_emb = text_model.encode(texts, normalize_embeddings=True)
img_emb = img_model.encode(posters, normalize_embeddings=True)

# Fuse the two modalities by concatenation, then re-normalize so that
# inner-product search over the fused vectors behaves like cosine similarity.
fused = np.hstack([text_emb, img_emb]).astype("float32")
fused /= np.linalg.norm(fused, axis=1, keepdims=True) + 1e-12

# Exact (brute-force) inner-product index over the fused embeddings.
index = faiss.IndexFlatIP(fused.shape[1])
index.add(fused)


def search_movies(query, k=3):
    # Encode the text query; the image half of the fused vector stays zero,
    # so retrieval is driven entirely by the text portion.
    q_text = text_model.encode([query], normalize_embeddings=True)
    q_img = np.zeros((1, img_emb.shape[1]), dtype="float32")
    q_fused = np.hstack([q_text, q_img]).astype("float32")
    q_fused /= np.linalg.norm(q_fused, axis=1, keepdims=True) + 1e-12

    _, indices = index.search(q_fused, k)

    # Return (poster, caption) pairs in the format gr.Gallery expects.
    results = []
    for idx in indices[0]:
        caption = f"{df.iloc[idx]['title']} — {df.iloc[idx]['synopsis']}"
        results.append((posters[idx], caption))
    return results

# Gradio UI: free-text description in, gallery of recommended posters out.
demo = gr.Interface(
    fn=search_movies,
    inputs=gr.Textbox(label="Describe the kind of movie you want:", value="sci-fi thriller"),
    outputs=gr.Gallery(label="Recommended Titles"),
    title="Unified Movie Embeddings Demo",
)

if __name__ == "__main__":
    demo.launch()
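# To try the demo, run this script and open the local URL Gradio prints
# (typically http://127.0.0.1:7860).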