import os
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from flashpack.integrations.diffusers import (
    FlashPackDiffusersModelMixin,
    FlashPackDiffusionPipeline,
)
from huggingface_hub import snapshot_download
# ============================================================
# 🧠 Device setup (CPU fallback safe)
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸ”§ Using device: {device}")
# ============================================================
# 🧩 Define FlashPack-integrated pipeline
# ============================================================
class FlashPackMyPipeline(DiffusionPipeline, FlashPackDiffusionPipeline):
    """DiffusionPipeline subclass that gains FlashPack's fast `from_pretrained_flashpack` loader via the mixin."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
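# How a repo like rahul7star/FlashPack is produced (hedged sketch): the mixin
# above pairs the loader used below with a saver. Assuming FlashPack exposes
# `save_pretrained_flashpack` as the counterpart to `from_pretrained_flashpack`
# (name unverified; check the flashpack docs), conversion would look like:
#
#   pipe = FlashPackMyPipeline.from_pretrained("some/base-model")  # hypothetical repo id
#   pipe.save_pretrained_flashpack("./flashpack-out")              # then upload to the Hub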
# ============================================================
# πŸš€ Load FlashPack pipeline
# ============================================================
def load_flashpack_pipeline(repo_id: str = "rahul7star/FlashPack"):
    """
    Load a FlashPack pipeline from the Hugging Face Hub.
    Falls back to a local snapshot if a network or metadata issue occurs.
    """
    print(f"πŸ” Loading FlashPack pipeline from: {repo_id}")
    try:
        # Try a direct Hub load first.
        pipeline = FlashPackMyPipeline.from_pretrained_flashpack(repo_id)
        print("βœ… Successfully loaded FlashPack pipeline from Hugging Face Hub.")
    except Exception as e:
        print(f"⚠️ Hub load failed: {e}")
        print("⏬ Attempting to load via snapshot_download...")
        try:
            # Download the full repo to the local cache, then load from disk.
            local_dir = snapshot_download(repo_id=repo_id)
            pipeline = FlashPackMyPipeline.from_pretrained_flashpack(local_dir)
            print(f"βœ… Loaded FlashPack pipeline from local snapshot: {local_dir}")
        except Exception as e2:
            raise RuntimeError(f"❌ Failed to load FlashPack model: {e2}") from e2
    pipeline.to(device)
    return pipeline
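# Minimal smoke test (sketch): the exact call signature depends on which
# pipeline class rahul7star/FlashPack resolves to, so treat this as a guess:
#   pipe = load_flashpack_pipeline()
#   out = pipe("a lighthouse at dusk")  # prompt-only call, mirroring generate_from_prompt below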
# ============================================================
# πŸ§ͺ Inference function
# ============================================================
def generate_from_prompt(prompt: str):
    """Run the pipeline and return (status message, image or None, video path or None)."""
    if not prompt or prompt.strip() == "":
        return "Please enter a valid prompt.", None, None
    try:
        output = pipeline(prompt)
        if hasattr(output, "images"):
            # Image pipelines return a list of PIL images.
            img = output.images[0]
            return "βœ… Generated successfully!", img, None
        elif hasattr(output, "frames"):
            # Video pipelines return frames; export them to an MP4 for Gradio.
            from diffusers.utils import export_to_video
            frames = output.frames
            video_path = "/tmp/generated.mp4"
            export_to_video(frames, video_path)
            return "βœ… Video generated successfully!", None, video_path
        else:
            return "⚠️ Unknown output format.", None, None
    except Exception as e:
        return f"❌ Inference error: {e}", None, None
# ============================================================
# βš™οΈ Load the model
# ============================================================
try:
    pipeline = load_flashpack_pipeline("rahul7star/FlashPack")
except Exception as e:
    raise SystemExit(f"🚫 Failed to load model: {e}")
# ============================================================
# 🧠 Gradio UI
# ============================================================
with gr.Blocks(title="FlashPack Model – rahul7star/FlashPack", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ⚑ FlashPack Model Inference
- Loaded from **rahul7star/FlashPack**
- Supports both image and video outputs (depending on model type)
""")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Enter your prompt", placeholder="e.g. A robot painting in the rain")
            run_btn = gr.Button("πŸš€ Generate", variant="primary")
        with gr.Column(scale=1):
            result_msg = gr.Textbox(label="Status", interactive=False)
            image_out = gr.Image(label="Generated Image")
            video_out = gr.Video(label="Generated Video")
    run_btn.click(
        generate_from_prompt,
        inputs=[prompt],
        outputs=[result_msg, image_out, video_out],  # matches the three return values of generate_from_prompt
    )
# ============================================================
# 🏁 Launch app
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)
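    # For long generations on Spaces, Gradio's request queue avoids HTTP
    # timeouts; an equivalent launch (standard Gradio API, untested with this
    # model's runtime) would be:
    #   demo.queue(max_size=8).launch(show_error=True)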