File size: 5,546 Bytes
575479b d9f0003 e79e752 d9f0003 31f781b 575479b d9f0003 575479b d9f0003 575479b d9f0003 5603837 8e1fa2b 5603837 575479b b8f4312 5603837 b8f4312 5603837 b8f4312 5603837 b8f4312 5603837 b8f4312 5603837 a18913a 5603837 a18913a 5603837 b8f4312 a18913a b8f4312 5603837 b8f4312 5603837 b8f4312 5603837 b8f4312 5603837 575479b d818a86 d9f0003 d818a86 d9f0003 9114f57 797ac3c 9114f57 0c06b9e 5603837 0c06b9e 575479b 9114f57 d9f0003 9114f57 d818a86 575479b 5603837 9b57f9a d9f0003 747613a d9f0003 5603837 575479b d9f0003 b8f4312 575479b d9f0003 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr
import asyncio
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from google import genai
from google.genai import types
from PIL import Image
from io import BytesIO
from huggingface_hub import hf_hub_download
# Cấu hình API Key
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("⚠️ GEMINI_API_KEY is missing!")
client = genai.Client(api_key=api_key)
def load_image_as_bytes(image_path):
"""Chuyển ảnh thành dữ liệu nhị phân với kiểm tra lỗi"""
try:
with Image.open(image_path) as img:
img = img.convert("RGB") # Đảm bảo ảnh là RGB
img_bytes = BytesIO()
img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
return img_bytes.getvalue() # Lấy dữ liệu nhị phân
except FileNotFoundError:
print(f"❌ Lỗi: Không tìm thấy file {image_path}")
return None
except UnidentifiedImageError:
print(f"❌ Lỗi: Không thể mở file {image_path} (định dạng không hợp lệ)")
return None
except Exception as e:
print(f"❌ Lỗi khi đọc ảnh {image_path}: {e}")
return None
async def generate_image(image_bytes_list, text_input):
"""Gửi request và nhận kết quả từ Gemini API (Xử lý lỗi 429)"""
while True:
try:
image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img]
contents = [text_input, image_parts] if image_parts else [text_input]
response = await asyncio.to_thread(
client.models.generate_content,
model="gemini-2.0-flash-exp-image-generation",
contents=contents,
config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
)
if not response or not response.candidates or not response.candidates[0].content:
print("❌ Lỗi: Phản hồi API không hợp lệ")
return []
images = []
for part in response.candidates[0].content.parts:
if part.inline_data is not None:
try:
img = Image.open(BytesIO(part.inline_data.data))
images.append(img)
except Exception as e:
print(f"❌ Lỗi khi hiển thị ảnh: {e}")
return images
except Exception as e:
error_message = str(e)
if "429" in error_message and "RESOURCE_EXHAUSTED" in error_message:
try:
# Trích xuất retryDelay từ JSON lỗi
error_json = json.loads(error_message.split("RESOURCE_EXHAUSTED. ")[1])
retry_delay = int(error_json["error"]["details"][-1]["retryDelay"][:-1]) # Lấy số giây từ '2s'
print(f"⚠️ Đã vượt quá hạn mức! Chờ {retry_delay} giây trước khi thử lại...")
# Đếm ngược
for i in range(retry_delay, 0, -1):
print(f"⏳ Thử lại sau {i} giây...", end="\r")
time.sleep(1)
print("\n🔄 Đang thử lại...")
continue # Thử lại request
except Exception as parse_error:
print(f"❌ Lỗi khi phân tích retryDelay: {parse_error}")
print(f"❌ Lỗi khi gọi API Gemini: {e}")
return []
async def process_request(images, text, num_requests):
"""Chạy nhiều request song song"""
image_bytes_list = [load_image_as_bytes(image) if image else None for image in images]
# Tạo nhiều request song song theo số lượng yêu cầu
tasks = [generate_image(image_bytes_list, text) for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
# Gộp tất cả ảnh từ các request
generated_images = [img for result in results for img in result]
# Resize ảnh giữ nguyên tỷ lệ chiều cao
resized_images = [img.resize((3840, int(img.height * (3840 / img.width))), Image.LANCZOS) for img in generated_images]
print("num_requests", num_requests)
print("tasks", len(tasks))
print("generated_images", len(generated_images))
print("resized_images", len(resized_images))
# return generated_images + resized_images
return resized_images
def gradio_interface(image1, image2, text, num_requests):
"""Hàm Gradio xử lý yêu cầu và trả về ảnh"""
images = [img for img in [image1, image2] if img]
return asyncio.run(process_request(images, text, num_requests))
# Tạo giao diện Gradio với slider từ 1 đến 8
demo = gr.Interface(
fn=gradio_interface,
inputs=[
gr.Image(type='filepath', label="Upload hình ảnh 1"),
gr.Image(type='filepath', label="Upload hình ảnh 2"),
gr.Textbox(label="Nhập yêu cầu chỉnh sửa hình ảnh"),
gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Số lượng ảnh cần tạo") # Tăng lên 8
],
outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4),
title="Chỉnh sửa ảnh bằng Gemini AI + SRCNN",
description="Upload tối đa 2 ảnh và nhập yêu cầu chỉnh sửa. Hiển thị ảnh gốc từ API và ảnh đã qua SRCNN.",
)
demo.launch() |