Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,33 +17,34 @@ if not api_key:
|
|
| 17 |
|
| 18 |
client = genai.Client(api_key=api_key)
|
| 19 |
|
| 20 |
-
# Định nghĩa mô hình SRCNN
|
| 21 |
-
def SRCNN():
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
-
# Khởi tạo mô hình SRCNN và tải trọng số từ Hugging Face
|
| 31 |
-
model = SRCNN()
|
| 32 |
-
pth_path = hf_hub_download(repo_id="sunbv56/srcnn", filename="srcnn_model.pth")
|
| 33 |
-
model.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
|
| 34 |
-
model.eval()
|
| 35 |
|
| 36 |
-
def upscale_image(image):
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
def load_image_as_bytes(image_path):
|
| 49 |
"""Chuyển ảnh thành dữ liệu nhị phân"""
|
|
@@ -53,11 +54,10 @@ def load_image_as_bytes(image_path):
|
|
| 53 |
img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
|
| 54 |
return img_bytes.getvalue() # Lấy dữ liệu nhị phân
|
| 55 |
|
| 56 |
-
async def generate_image(
|
| 57 |
"""Gửi request và nhận kết quả từ Gemini API"""
|
| 58 |
-
|
| 59 |
-
if
|
| 60 |
-
contents.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type="image/jpeg")))
|
| 61 |
|
| 62 |
response = await asyncio.to_thread(
|
| 63 |
client.models.generate_content,
|
|
@@ -72,9 +72,6 @@ async def generate_image(image_bytes, text_input):
|
|
| 72 |
if part.inline_data is not None:
|
| 73 |
img = Image.open(BytesIO(part.inline_data.data))
|
| 74 |
images.append(img)
|
| 75 |
-
else:
|
| 76 |
-
print("⚠️ Gemini API không trả về kết quả hợp lệ!")
|
| 77 |
-
|
| 78 |
return images
|
| 79 |
|
| 80 |
async def process_request(image, text, num_requests):
|
|
@@ -85,10 +82,11 @@ async def process_request(image, text, num_requests):
|
|
| 85 |
|
| 86 |
# Hợp nhất danh sách ảnh từ các request
|
| 87 |
original_images = [img for result in results for img in result]
|
| 88 |
-
resized_images = [img.resize((2560,
|
| 89 |
-
srcnn_images = [upscale_image(img) for img in resized_images]
|
| 90 |
|
| 91 |
-
return
|
|
|
|
| 92 |
|
| 93 |
def gradio_interface(image, text, num_requests):
|
| 94 |
"""Hàm Gradio xử lý yêu cầu và trả về ảnh"""
|
|
|
|
| 17 |
|
| 18 |
client = genai.Client(api_key=api_key)
|
| 19 |
|
| 20 |
+
# # Định nghĩa mô hình SRCNN
|
| 21 |
+
# def SRCNN():
|
| 22 |
+
# return nn.Sequential(
|
| 23 |
+
# nn.Conv2d(3
|
| 24 |
+
# , 64, kernel_size=9, padding=4),
|
| 25 |
+
# nn.ReLU(inplace=True),
|
| 26 |
+
# nn.Conv2d(64, 32, kernel_size=5, padding=2),
|
| 27 |
+
# nn.ReLU(inplace=True),
|
| 28 |
+
# nn.Conv2d(32, 3, kernel_size=5, padding=2)
|
| 29 |
+
# )
|
| 30 |
|
| 31 |
+
# # Khởi tạo mô hình SRCNN và tải trọng số từ Hugging Face
|
| 32 |
+
# model = SRCNN()
|
| 33 |
+
# pth_path = hf_hub_download(repo_id="sunbv56/srcnn", filename="srcnn_model.pth")
|
| 34 |
+
# model.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
|
| 35 |
+
# model.eval()
|
| 36 |
|
| 37 |
+
# def upscale_image(image):
|
| 38 |
+
# """Nâng cấp độ phân giải ảnh bằng SRCNN"""
|
| 39 |
+
# transform = transforms.Compose([
|
| 40 |
+
# transforms.ToTensor(),
|
| 41 |
+
# transforms.Lambda(lambda x: x.unsqueeze(0)) # Thêm batch dimension
|
| 42 |
+
# ])
|
| 43 |
+
# input_tensor = transform(image)
|
| 44 |
+
# with torch.no_grad():
|
| 45 |
+
# output_tensor = model(input_tensor)
|
| 46 |
+
# output_image = transforms.ToPILImage()(output_tensor.squeeze(0))
|
| 47 |
+
# return output_image
|
| 48 |
|
| 49 |
def load_image_as_bytes(image_path):
|
| 50 |
"""Chuyển ảnh thành dữ liệu nhị phân"""
|
|
|
|
| 54 |
img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
|
| 55 |
return img_bytes.getvalue() # Lấy dữ liệu nhị phân
|
| 56 |
|
| 57 |
+
async def generate_image(image_bytes_list, text_input):
|
| 58 |
"""Gửi request và nhận kết quả từ Gemini API"""
|
| 59 |
+
image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img]
|
| 60 |
+
contents = [text_input, image_parts] if image_parts else [text_input]
|
|
|
|
| 61 |
|
| 62 |
response = await asyncio.to_thread(
|
| 63 |
client.models.generate_content,
|
|
|
|
| 72 |
if part.inline_data is not None:
|
| 73 |
img = Image.open(BytesIO(part.inline_data.data))
|
| 74 |
images.append(img)
|
|
|
|
|
|
|
|
|
|
| 75 |
return images
|
| 76 |
|
| 77 |
async def process_request(image, text, num_requests):
|
|
|
|
| 82 |
|
| 83 |
# Hợp nhất danh sách ảnh từ các request
|
| 84 |
original_images = [img for result in results for img in result]
|
| 85 |
+
resized_images = [img.resize((2560, int(img.height * (2560 / img.width)))), Image.LANCZOS) for img in original_images] # Resize trước khi upscale
|
| 86 |
+
# srcnn_images = [upscale_image(img) for img in resized_images]
|
| 87 |
|
| 88 |
+
return resized_images # 4 ảnh gốc
|
| 89 |
+
# return resized_images + srcnn_images # 4 ảnh gốc + 4 ảnh đã qua SRCNN
|
| 90 |
|
| 91 |
def gradio_interface(image, text, num_requests):
|
| 92 |
"""Hàm Gradio xử lý yêu cầu và trả về ảnh"""
|