import gradio as gr import asyncio import os import torch import torchvision.transforms as transforms from torchvision.utils import save_image from google import genai from google.genai import types from PIL import Image from io import BytesIO # Cấu hình API Key api_key = os.getenv("GEMINI_API_KEY") if not api_key: raise ValueError("⚠️ GEMINI_API_KEY is missing!") client = genai.Client(api_key=api_key) # Load SRCNN từ Torch Hub model = torch.hub.load('pytorch/vision:v0.10.0', 'srcnn', pretrained=True) model.eval() def upscale_image(image): """Nâng cấp độ phân giải ảnh bằng SRCNN""" transform = transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: x.unsqueeze(0)) # Thêm batch dimension ]) img_tensor = transform(image) with torch.no_grad(): upscaled_tensor = model(img_tensor) upscaled_image = transforms.ToPILImage()(upscaled_tensor.squeeze(0)) return upscaled_image def load_image_as_bytes(image_path): """Chuyển ảnh thành dữ liệu nhị phân""" with Image.open(image_path) as img: img = img.convert("RGB") # Đảm bảo ảnh là RGB img = upscale_image(img) # SRCNN trước khi gửi đi img_bytes = BytesIO() img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer return img_bytes.getvalue() # Lấy dữ liệu nhị phân async def generate_image(image_bytes, text_input): """Gửi request và nhận kết quả từ Gemini API""" contents = [text_input] if image_bytes: contents.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type="image/jpeg"))) response = await asyncio.to_thread( client.models.generate_content, model="gemini-2.0-flash-exp-image-generation", contents=contents, config=types.GenerateContentConfig(response_modalities=['Text', 'Image']) ) images = [] for part in response.candidates[0].content.parts: if part.inline_data is not None: img = Image.open(BytesIO(part.inline_data.data)) img = upscale_image(img) # SRCNN sau khi nhận ảnh từ Gemini images.append(img) return images async def process_request(image, text, num_requests): """Chạy nhiều request song song""" image_bytes = load_image_as_bytes(image) if image else None tasks = [generate_image(image_bytes, text) for _ in range(num_requests)] results = await asyncio.gather(*tasks) # Hợp nhất danh sách ảnh từ các request all_images = [img for result in results for img in result] return all_images def gradio_interface(image, text, num_requests): """Hàm Gradio xử lý yêu cầu và trả về ảnh""" return asyncio.run(process_request(image, text, num_requests)) # Tạo giao diện Gradio demo = gr.Interface( fn=gradio_interface, inputs=[ gr.Image(type='filepath', label="Upload hình ảnh"), gr.Textbox(label="Nhập yêu cầu chỉnh sửa hình ảnh"), gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Số lượng ảnh cần tạo") ], outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4), title="Chỉnh sửa ảnh bằng Gemini AI + SRCNN", description="Upload ảnh và nhập yêu cầu chỉnh sửa. Ảnh được nâng cấp độ phân giải trước và sau khi xử lý.", ) demo.launch()