File size: 6,056 Bytes
575479b
d9f0003
 
e79e752
 
 
d9f0003
 
 
 
31f781b
d28e1dd
 
575479b
d9f0003
 
 
 
575479b
d9f0003
575479b
d9f0003
5603837
 
 
36ece93
5603837
 
 
 
8e1fa2b
5603837
 
 
 
 
 
 
575479b
cae7746
 
 
9bf5f5a
 
 
 
 
 
 
5603837
9bf5f5a
 
 
 
 
 
5603837
9bf5f5a
cae7746
 
 
 
 
 
 
 
5603837
9bf5f5a
 
f35554d
9bf5f5a
 
 
 
cae7746
 
5603837
9bf5f5a
 
cae7746
9bf5f5a
 
cae7746
9bf5f5a
cae7746
 
 
 
9bf5f5a
 
cae7746
 
9bf5f5a
 
cae7746
9bf5f5a
cae7746
9bf5f5a
 
d818a86
d9f0003
d818a86
d9f0003
9114f57
 
 
 
 
 
797ac3c
9114f57
0c06b9e
5603837
 
 
 
 
0c06b9e
 
575479b
9114f57
d9f0003
9114f57
d818a86
575479b
5603837
9b57f9a
d9f0003
 
747613a
 
d9f0003
5603837
575479b
d9f0003
e7fca3b
 
575479b
 
d9f0003
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import asyncio
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from google import genai
from google.genai import types
from PIL import Image
from io import BytesIO
from huggingface_hub import hf_hub_download
import time
import json

# Cấu hình API Key
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    raise ValueError("⚠️ GEMINI_API_KEY is missing!")

client = genai.Client(api_key=api_key)

def load_image_as_bytes(image_path):
    """Chuyển ảnh thành dữ liệu nhị phân với kiểm tra lỗi"""
    try:
        with Image.open(image_path) as img:
            img = img.convert("RGB")  # Đảm bảo ảnh là RGB
            img_bytes = BytesIO()
            img.save(img_bytes, format="JPEG")  # Lưu ảnh vào buffer
            return img_bytes.getvalue()  # Lấy dữ liệu nhị phân
    except FileNotFoundError:
        print(f"❌ Lỗi: Không tìm thấy file {image_path}")
        return None
    except UnidentifiedImageError:
        print(f"❌ Lỗi: Không thể mở file {image_path} (định dạng không hợp lệ)")
        return None
    except Exception as e:
        print(f"❌ Lỗi khi đọc ảnh {image_path}: {e}")
        return None

async def generate_image(image_bytes_list, text_input, max_retries=5, retry_delay=2):
    """Gửi request và nhận kết quả từ Gemini API (Xử lý lỗi 429 và lỗi phản hồi không hợp lệ)"""
    for attempt in range(1, max_retries + 1):
        try:
            image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img]
            
            contents = [text_input] + (
                [image_parts] if len(image_parts) == 1 else
                sum([[f"Image {idx+1}:", part] for idx, part in enumerate(image_parts)], []) if len(image_parts) == 2 else []
            )

            response = await asyncio.to_thread(
                client.models.generate_content,
                model="gemini-2.0-flash-exp-image-generation",
                contents=contents,
                config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
            )

            if not response or not response.candidates or not response.candidates[0].content:
                print(f"⚠️ Phản hồi API không hợp lệ (thử {attempt}/{max_retries})")
                if attempt < max_retries:
                    print(f"🔄 Đang thử lại sau {retry_delay} giây...")
                    time.sleep(retry_delay)
                    continue
                else:
                    print("❌ Đã thử lại nhiều lần nhưng vẫn thất bại!")
                    return []

            images = []
            for part in response.candidates[0].content.parts:
                if part.inline_data is not None:
                    try:
                        img = Image.open(BytesIO(part.inline_data.data))
                        images.append(img)
                    except Exception as e:
                        print(f"❌ Lỗi khi hiển thị ảnh: {e}")
            return images

        except Exception as e:
            error_message = str(e)
            if "429" in error_message and "RESOURCE_EXHAUSTED" in error_message:
                try:
                    error_json = json.loads(error_message.split("RESOURCE_EXHAUSTED. ")[1])
                    retry_seconds = int(error_json["error"]["details"][-1]["retryDelay"][:-1])  # Lấy số giây từ '2s'

                    print(f"⚠️ Quá tải API! Chờ {retry_seconds} giây trước khi thử lại...")

                    for i in range(retry_seconds, 0, -1):
                        print(f"⏳ Thử lại sau {i} giây...", end="\r")
                        time.sleep(1)

                    print("\n🔄 Đang thử lại...")
                    continue  

                except Exception as parse_error:
                    print(f"❌ Lỗi khi phân tích retryDelay: {parse_error}")

            print(f"❌ Lỗi khi gọi API Gemini: {e}")
            return []

async def process_request(images, text, num_requests):
    """Chạy nhiều request song song"""
    image_bytes_list = [load_image_as_bytes(image) if image else None for image in images]
    
    # Tạo nhiều request song song theo số lượng yêu cầu
    tasks = [generate_image(image_bytes_list, text) for _ in range(num_requests)]
    results = await asyncio.gather(*tasks)

    # Gộp tất cả ảnh từ các request
    generated_images = [img for result in results for img in result]
    
    # Resize ảnh giữ nguyên tỷ lệ chiều cao
    resized_images = [img.resize((3840, int(img.height * (3840 / img.width))), Image.LANCZOS) for img in generated_images]
    print("num_requests", num_requests)
    print("tasks", len(tasks))
    print("generated_images", len(generated_images))
    print("resized_images", len(resized_images))
    
    # return generated_images + resized_images
    return resized_images

def gradio_interface(image1, image2, text, num_requests):
    """Hàm Gradio xử lý yêu cầu và trả về ảnh"""
    images = [img for img in [image1, image2] if img]
    return asyncio.run(process_request(images, text, num_requests))

# Tạo giao diện Gradio với slider từ 1 đến 8
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type='filepath', label="Upload hình ảnh 1"),
        gr.Image(type='filepath', label="Upload hình ảnh 2"),
        gr.Textbox(label="Nhập yêu cầu chỉnh sửa hình ảnh"),
        gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Số lượng ảnh cần tạo")  # Tăng lên 8
    ],
    outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4),
    title="Chỉnh sửa ảnh bằng Gemini AI + Upscale",
    description="Upload tối đa 2 ảnh và nhập yêu cầu chỉnh sửa. Hiển thị ảnh gốc từ API và ảnh đã qua Upscale.",
)

demo.launch()