|
|
import gradio as gr |
|
|
import asyncio |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.transforms as transforms |
|
|
from google import genai |
|
|
from google.genai import types |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
api_key = os.getenv("GEMINI_API_KEY") |
|
|
if not api_key: |
|
|
raise ValueError("⚠️ GEMINI_API_KEY is missing!") |
|
|
|
|
|
client = genai.Client(api_key=api_key) |
|
|
|
|
|
def load_image_as_bytes(image_path): |
|
|
"""Chuyển ảnh thành dữ liệu nhị phân với kiểm tra lỗi""" |
|
|
try: |
|
|
with Image.open(image_path) as img: |
|
|
img = img.convert("RGB") |
|
|
img_bytes = BytesIO() |
|
|
img.save(img_bytes, format="JPEG") |
|
|
return img_bytes.getvalue() |
|
|
except FileNotFoundError: |
|
|
print(f"❌ Lỗi: Không tìm thấy file {image_path}") |
|
|
return None |
|
|
except UnidentifiedImageError: |
|
|
print(f"❌ Lỗi: Không thể mở file {image_path} (định dạng không hợp lệ)") |
|
|
return None |
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi khi đọc ảnh {image_path}: {e}") |
|
|
return None |
|
|
|
|
|
async def generate_image(image_bytes_list, text_input): |
|
|
"""Gửi request và nhận kết quả từ Gemini API (Xử lý lỗi 429)""" |
|
|
while True: |
|
|
try: |
|
|
image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img] |
|
|
contents = [text_input, image_parts] if image_parts else [text_input] |
|
|
|
|
|
response = await asyncio.to_thread( |
|
|
client.models.generate_content, |
|
|
model="gemini-2.0-flash-exp-image-generation", |
|
|
contents=contents, |
|
|
config=types.GenerateContentConfig(response_modalities=['Text', 'Image']) |
|
|
) |
|
|
|
|
|
if not response or not response.candidates or not response.candidates[0].content: |
|
|
print("❌ Lỗi: Phản hồi API không hợp lệ") |
|
|
return [] |
|
|
|
|
|
images = [] |
|
|
for part in response.candidates[0].content.parts: |
|
|
if part.inline_data is not None: |
|
|
try: |
|
|
img = Image.open(BytesIO(part.inline_data.data)) |
|
|
images.append(img) |
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi khi hiển thị ảnh: {e}") |
|
|
return images |
|
|
|
|
|
except Exception as e: |
|
|
error_message = str(e) |
|
|
if "429" in error_message and "RESOURCE_EXHAUSTED" in error_message: |
|
|
try: |
|
|
|
|
|
error_json = json.loads(error_message.split("RESOURCE_EXHAUSTED. ")[1]) |
|
|
retry_delay = int(error_json["error"]["details"][-1]["retryDelay"][:-1]) |
|
|
|
|
|
print(f"⚠️ Đã vượt quá hạn mức! Chờ {retry_delay} giây trước khi thử lại...") |
|
|
|
|
|
|
|
|
for i in range(retry_delay, 0, -1): |
|
|
print(f"⏳ Thử lại sau {i} giây...", end="\r") |
|
|
time.sleep(1) |
|
|
|
|
|
print("\n🔄 Đang thử lại...") |
|
|
continue |
|
|
|
|
|
except Exception as parse_error: |
|
|
print(f"❌ Lỗi khi phân tích retryDelay: {parse_error}") |
|
|
|
|
|
print(f"❌ Lỗi khi gọi API Gemini: {e}") |
|
|
return [] |
|
|
|
|
|
async def process_request(images, text, num_requests): |
|
|
"""Chạy nhiều request song song""" |
|
|
image_bytes_list = [load_image_as_bytes(image) if image else None for image in images] |
|
|
|
|
|
|
|
|
tasks = [generate_image(image_bytes_list, text) for _ in range(num_requests)] |
|
|
results = await asyncio.gather(*tasks) |
|
|
|
|
|
|
|
|
generated_images = [img for result in results for img in result] |
|
|
|
|
|
|
|
|
resized_images = [img.resize((3840, int(img.height * (3840 / img.width))), Image.LANCZOS) for img in generated_images] |
|
|
print("num_requests", num_requests) |
|
|
print("tasks", len(tasks)) |
|
|
print("generated_images", len(generated_images)) |
|
|
print("resized_images", len(resized_images)) |
|
|
|
|
|
|
|
|
return resized_images |
|
|
|
|
|
def gradio_interface(image1, image2, text, num_requests): |
|
|
"""Hàm Gradio xử lý yêu cầu và trả về ảnh""" |
|
|
images = [img for img in [image1, image2] if img] |
|
|
return asyncio.run(process_request(images, text, num_requests)) |
|
|
|
|
|
|
|
|
demo = gr.Interface( |
|
|
fn=gradio_interface, |
|
|
inputs=[ |
|
|
gr.Image(type='filepath', label="Upload hình ảnh 1"), |
|
|
gr.Image(type='filepath', label="Upload hình ảnh 2"), |
|
|
gr.Textbox(label="Nhập yêu cầu chỉnh sửa hình ảnh"), |
|
|
gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Số lượng ảnh cần tạo") |
|
|
], |
|
|
outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4), |
|
|
title="Chỉnh sửa ảnh bằng Gemini AI + SRCNN", |
|
|
description="Upload tối đa 2 ảnh và nhập yêu cầu chỉnh sửa. Hiển thị ảnh gốc từ API và ảnh đã qua SRCNN.", |
|
|
) |
|
|
|
|
|
demo.launch() |