Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import asyncio | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from google import genai | |
| from google.genai import types | |
| from PIL import Image | |
| from io import BytesIO | |
| from huggingface_hub import hf_hub_download | |
| import time | |
| import json | |
| # Cấu hình API Key | |
| api_key = os.getenv("GEMINI_API_KEY") | |
| if not api_key: | |
| raise ValueError("⚠️ GEMINI_API_KEY is missing!") | |
| client = genai.Client(api_key=api_key) | |
| def load_image_as_bytes(image_path): | |
| """Chuyển ảnh thành dữ liệu nhị phân với kiểm tra lỗi""" | |
| try: | |
| with Image.open(image_path) as img: | |
| img = img.convert("RGB") # Đảm bảo ảnh là RGB | |
| img_bytes = BytesIO() | |
| img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer | |
| return img_bytes.getvalue() # Lấy dữ liệu nhị phân | |
| except FileNotFoundError: | |
| print(f"❌ Lỗi: Không tìm thấy file {image_path}") | |
| return None | |
| except UnidentifiedImageError: | |
| print(f"❌ Lỗi: Không thể mở file {image_path} (định dạng không hợp lệ)") | |
| return None | |
| except Exception as e: | |
| print(f"❌ Lỗi khi đọc ảnh {image_path}: {e}") | |
| return None | |
| async def generate_image(image_bytes_list, text_input, max_retries=5, retry_delay=2): | |
| """Gửi request và nhận kết quả từ Gemini API (Xử lý lỗi 429 và lỗi phản hồi không hợp lệ)""" | |
| for attempt in range(1, max_retries + 1): | |
| try: | |
| image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img] | |
| contents = [text_input] + ( | |
| [image_parts] if len(image_parts) == 1 else | |
| sum([[f"Image {idx+1}:", part] for idx, part in enumerate(image_parts)], []) if len(image_parts) == 2 else [] | |
| ) | |
| response = await asyncio.to_thread( | |
| client.models.generate_content, | |
| model="gemini-2.0-flash-exp-image-generation", | |
| contents=contents, | |
| config=types.GenerateContentConfig(response_modalities=['Text', 'Image']) | |
| ) | |
| if not response or not response.candidates or not response.candidates[0].content: | |
| print(f"⚠️ Phản hồi API không hợp lệ (thử {attempt}/{max_retries})") | |
| if attempt < max_retries: | |
| print(f"🔄 Đang thử lại sau {retry_delay} giây...") | |
| time.sleep(retry_delay) | |
| continue | |
| else: | |
| print("❌ Đã thử lại nhiều lần nhưng vẫn thất bại!") | |
| return [] | |
| images = [] | |
| for part in response.candidates[0].content.parts: | |
| if part.inline_data is not None: | |
| try: | |
| img = Image.open(BytesIO(part.inline_data.data)) | |
| images.append(img) | |
| except Exception as e: | |
| print(f"❌ Lỗi khi hiển thị ảnh: {e}") | |
| return images | |
| except Exception as e: | |
| error_message = str(e) | |
| if "429" in error_message and "RESOURCE_EXHAUSTED" in error_message: | |
| try: | |
| error_json = json.loads(error_message.split("RESOURCE_EXHAUSTED. ")[1]) | |
| retry_seconds = int(error_json["error"]["details"][-1]["retryDelay"][:-1]) # Lấy số giây từ '2s' | |
| print(f"⚠️ Quá tải API! Chờ {retry_seconds} giây trước khi thử lại...") | |
| for i in range(retry_seconds, 0, -1): | |
| print(f"⏳ Thử lại sau {i} giây...", end="\r") | |
| time.sleep(1) | |
| print("\n🔄 Đang thử lại...") | |
| continue | |
| except Exception as parse_error: | |
| print(f"❌ Lỗi khi phân tích retryDelay: {parse_error}") | |
| print(f"❌ Lỗi khi gọi API Gemini: {e}") | |
| return [] | |
| async def process_request(images, text, num_requests): | |
| """Chạy nhiều request song song""" | |
| image_bytes_list = [load_image_as_bytes(image) if image else None for image in images] | |
| # Tạo nhiều request song song theo số lượng yêu cầu | |
| tasks = [generate_image(image_bytes_list, text) for _ in range(num_requests)] | |
| results = await asyncio.gather(*tasks) | |
| # Gộp tất cả ảnh từ các request | |
| generated_images = [img for result in results for img in result] | |
| # Resize ảnh giữ nguyên tỷ lệ chiều cao | |
| resized_images = [img.resize((3840, int(img.height * (3840 / img.width))), Image.LANCZOS) for img in generated_images] | |
| print("num_requests", num_requests) | |
| print("tasks", len(tasks)) | |
| print("generated_images", len(generated_images)) | |
| print("resized_images", len(resized_images)) | |
| # return generated_images + resized_images | |
| return resized_images | |
| def gradio_interface(image1, image2, text, num_requests): | |
| """Hàm Gradio xử lý yêu cầu và trả về ảnh""" | |
| images = [img for img in [image1, image2] if img] | |
| return asyncio.run(process_request(images, text, num_requests)) | |
| # Tạo giao diện Gradio với slider từ 1 đến 8 | |
| demo = gr.Interface( | |
| fn=gradio_interface, | |
| inputs=[ | |
| gr.Image(type='filepath', label="Upload hình ảnh 1"), | |
| gr.Image(type='filepath', label="Upload hình ảnh 2"), | |
| gr.Textbox(label="Nhập yêu cầu chỉnh sửa hình ảnh"), | |
| gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Số lượng ảnh cần tạo") # Tăng lên 8 | |
| ], | |
| outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4), | |
| title="Chỉnh sửa ảnh bằng Gemini AI + Upscale", | |
| description="Upload tối đa 2 ảnh và nhập yêu cầu chỉnh sửa. Hiển thị ảnh gốc từ API và ảnh đã qua Upscale.", | |
| ) | |
| demo.launch() |