import gradio as gr import asyncio import os import torch import torch.nn as nn import torchvision.transforms as transforms from google import genai from google.genai import types from PIL import Image from io import BytesIO from huggingface_hub import hf_hub_download import time import json # Cấu hình API Key api_key = os.getenv("GEMINI_API_KEY") if not api_key: raise ValueError("⚠️ GEMINI_API_KEY is missing!") client = genai.Client(api_key=api_key) def load_image_as_bytes(image_path): """Chuyển ảnh thành dữ liệu nhị phân với kiểm tra lỗi""" try: with Image.open(image_path) as img: img = img.convert("RGB") # Đảm bảo ảnh là RGB img_bytes = BytesIO() img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer return img_bytes.getvalue() # Lấy dữ liệu nhị phân except FileNotFoundError: print(f"❌ Lỗi: Không tìm thấy file {image_path}") return None except UnidentifiedImageError: print(f"❌ Lỗi: Không thể mở file {image_path} (định dạng không hợp lệ)") return None except Exception as e: print(f"❌ Lỗi khi đọc ảnh {image_path}: {e}") return None async def generate_image(image_bytes_list, text_input, max_retries=5, retry_delay=2): """Gửi request và nhận kết quả từ Gemini API (Xử lý lỗi 429 và lỗi phản hồi không hợp lệ)""" for attempt in range(1, max_retries + 1): try: image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img] contents = [text_input] + ( [image_parts] if len(image_parts) == 1 else sum([[f"Image {idx+1}:", part] for idx, part in enumerate(image_parts)], []) if len(image_parts) == 2 else [] ) response = await asyncio.to_thread( client.models.generate_content, model="gemini-2.0-flash-exp-image-generation", contents=contents, config=types.GenerateContentConfig(response_modalities=['Text', 'Image']) ) if not response or not response.candidates or not response.candidates[0].content: print(f"⚠️ Phản hồi API không hợp lệ (thử {attempt}/{max_retries})") if attempt < max_retries: print(f"🔄 Đang thử lại sau {retry_delay} giây...") time.sleep(retry_delay) continue else: print("❌ Đã thử lại nhiều lần nhưng vẫn thất bại!") return [] images = [] for part in response.candidates[0].content.parts: if part.inline_data is not None: try: img = Image.open(BytesIO(part.inline_data.data)) images.append(img) except Exception as e: print(f"❌ Lỗi khi hiển thị ảnh: {e}") return images except Exception as e: error_message = str(e) if "429" in error_message and "RESOURCE_EXHAUSTED" in error_message: try: error_json = json.loads(error_message.split("RESOURCE_EXHAUSTED. ")[1]) retry_seconds = int(error_json["error"]["details"][-1]["retryDelay"][:-1]) # Lấy số giây từ '2s' print(f"⚠️ Quá tải API! Chờ {retry_seconds} giây trước khi thử lại...") for i in range(retry_seconds, 0, -1): print(f"⏳ Thử lại sau {i} giây...", end="\r") time.sleep(1) print("\n🔄 Đang thử lại...") continue except Exception as parse_error: print(f"❌ Lỗi khi phân tích retryDelay: {parse_error}") print(f"❌ Lỗi khi gọi API Gemini: {e}") return [] async def process_request(images, text, num_requests): """Chạy nhiều request song song""" image_bytes_list = [load_image_as_bytes(image) if image else None for image in images] # Tạo nhiều request song song theo số lượng yêu cầu tasks = [generate_image(image_bytes_list, text) for _ in range(num_requests)] results = await asyncio.gather(*tasks) # Gộp tất cả ảnh từ các request generated_images = [img for result in results for img in result] # Resize ảnh giữ nguyên tỷ lệ chiều cao resized_images = [img.resize((3840, int(img.height * (3840 / img.width))), Image.LANCZOS) for img in generated_images] print("num_requests", num_requests) print("tasks", len(tasks)) print("generated_images", len(generated_images)) print("resized_images", len(resized_images)) # return generated_images + resized_images return resized_images def gradio_interface(image1, image2, text, num_requests): """Hàm Gradio xử lý yêu cầu và trả về ảnh""" images = [img for img in [image1, image2] if img] return asyncio.run(process_request(images, text, num_requests)) # Tạo giao diện Gradio với slider từ 1 đến 8 demo = gr.Interface( fn=gradio_interface, inputs=[ gr.Image(type='filepath', label="Upload hình ảnh 1"), gr.Image(type='filepath', label="Upload hình ảnh 2"), gr.Textbox(label="Nhập yêu cầu chỉnh sửa hình ảnh"), gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Số lượng ảnh cần tạo") # Tăng lên 8 ], outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4), title="Chỉnh sửa ảnh bằng Gemini AI + Upscale", description="Upload tối đa 2 ảnh và nhập yêu cầu chỉnh sửa. Hiển thị ảnh gốc từ API và ảnh đã qua Upscale.", ) demo.launch()