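"""Gradio demo for the rahul7star/FlashPack model.

Loads a diffusion pipeline from the Hugging Face Hub via FlashPack
(falling back to a local snapshot_download) and exposes a simple
prompt-to-image / prompt-to-video interface.
"""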
import os
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from flashpack.integrations.diffusers import (
    FlashPackDiffusersModelMixin,
    FlashPackDiffusionPipeline,
)
from huggingface_hub import snapshot_download
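
# Pick the torch device once at import time: CUDA when a GPU is visible, otherwise CPU.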
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {device}")
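

# Concrete pipeline type that mixes FlashPackDiffusionPipeline into diffusers'
# DiffusionPipeline, so checkpoints can be loaded with from_pretrained_flashpack
# (see load_flashpack_pipeline below).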
class FlashPackMyPipeline(DiffusionPipeline, FlashPackDiffusionPipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
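

# Loader with a two-stage strategy: try the Hub repo id first, then fall back to a
# locally downloaded snapshot of the repository.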
def load_flashpack_pipeline(repo_id: str = "rahul7star/FlashPack"):
    """
    Loads a FlashPack pipeline from the Hugging Face Hub.
    Falls back to a local snapshot if a network or metadata issue occurs.
    """
    print(f"📦 Loading FlashPack pipeline from: {repo_id}")

    try:
        # Preferred path: load directly from the Hub by repo id.
        pipeline = FlashPackMyPipeline.from_pretrained_flashpack(repo_id)
        print("✅ Successfully loaded FlashPack pipeline from Hugging Face Hub.")
    except Exception as e:
        print(f"⚠️ Hub load failed: {e}")
        print("⬇️ Attempting to load via snapshot_download...")
        try:
            # Fallback: download the full repository snapshot and load from disk.
            local_dir = snapshot_download(repo_id=repo_id)
            pipeline = FlashPackMyPipeline.from_pretrained_flashpack(local_dir)
            print(f"✅ Loaded FlashPack pipeline from local snapshot: {local_dir}")
        except Exception as e2:
            raise RuntimeError(f"❌ Failed to load FlashPack model: {e2}") from e2

    pipeline.to(device)
    return pipeline
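

# Gradio callback: run the loaded pipeline on the prompt and route the result to
# the image or video output, depending on what the underlying model returns.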
def generate_from_prompt(prompt: str):
    if not prompt or prompt.strip() == "":
        return "Please enter a valid prompt.", None, None

    try:
        output = pipeline(prompt)
        if hasattr(output, "images"):
            # Image pipeline: return the first generated image.
            img = output.images[0]
            return "✅ Generated successfully!", img, None
        elif hasattr(output, "frames"):
            # Video pipeline: export the generated frames to an MP4 file.
            frames = output.frames
            # Some video pipelines return a batch (a list of frame lists); use the first entry.
            if frames and isinstance(frames[0], (list, tuple)):
                frames = frames[0]
            video_path = "/tmp/generated.mp4"
            from diffusers.utils import export_to_video
            export_to_video(frames, video_path)
            return "✅ Video generated successfully!", None, video_path
        else:
            return "⚠️ Unknown output format.", None, None
    except Exception as e:
        return f"❌ Inference error: {e}", None, None
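

# Load the pipeline once at startup so every Gradio request reuses the same weights.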
try:
    pipeline = load_flashpack_pipeline("rahul7star/FlashPack")
except Exception as e:
    raise SystemExit(f"🚫 Failed to load model: {e}")
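

# UI layout: prompt box and Generate button on the left; status text, image, and
# video outputs on the right.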
with gr.Blocks(title="FlashPack Model - rahul7star/FlashPack", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ⚡ FlashPack Model Inference
    - Loaded from **rahul7star/FlashPack**
    - Supports both image and video outputs (depending on model type)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Enter your prompt", placeholder="e.g. A robot painting in the rain")
            run_btn = gr.Button("🚀 Generate", variant="primary")
        with gr.Column(scale=1):
            result_msg = gr.Textbox(label="Status", interactive=False)
            image_out = gr.Image(label="Generated Image")
            video_out = gr.Video(label="Generated Video")

    # Wire the button to the callback; it returns (status, image, video).
    run_btn.click(
        generate_from_prompt,
        inputs=[prompt],
        outputs=[result_msg, image_out, video_out],
    )
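

# Note: demo.launch() as written is sufficient on a managed host such as a Hugging Face
# Space; for a LAN-visible local run, demo.launch(server_name="0.0.0.0") is a common option.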
if __name__ == "__main__":
    demo.launch(show_error=True)