sunbv56 commited on
Commit
c426cb4
·
verified ·
1 Parent(s): 96a99a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -35
app.py CHANGED
@@ -17,33 +17,34 @@ if not api_key:
17
 
18
  client = genai.Client(api_key=api_key)
19
 
20
- # Định nghĩa mô hình SRCNN
21
- def SRCNN():
22
- return nn.Sequential(
23
- nn.Conv2d(3, 64, kernel_size=9, padding=4),
24
- nn.ReLU(inplace=True),
25
- nn.Conv2d(64, 32, kernel_size=5, padding=2),
26
- nn.ReLU(inplace=True),
27
- nn.Conv2d(32, 3, kernel_size=5, padding=2)
28
- )
 
29
 
30
- # Khởi tạo mô hình SRCNN và tải trọng số từ Hugging Face
31
- model = SRCNN()
32
- pth_path = hf_hub_download(repo_id="sunbv56/srcnn", filename="srcnn_model.pth")
33
- model.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
34
- model.eval()
35
 
36
- def upscale_image(image):
37
- """Nâng cấp độ phân giải ảnh bằng SRCNN"""
38
- transform = transforms.Compose([
39
- transforms.ToTensor(),
40
- transforms.Lambda(lambda x: x.unsqueeze(0)) # Thêm batch dimension
41
- ])
42
- input_tensor = transform(image)
43
- with torch.no_grad():
44
- output_tensor = model(input_tensor)
45
- output_image = transforms.ToPILImage()(output_tensor.squeeze(0))
46
- return output_image
47
 
48
  def load_image_as_bytes(image_path):
49
  """Chuyển ảnh thành dữ liệu nhị phân"""
@@ -53,11 +54,10 @@ def load_image_as_bytes(image_path):
53
  img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
54
  return img_bytes.getvalue() # Lấy dữ liệu nhị phân
55
 
56
- async def generate_image(image_bytes, text_input):
57
  """Gửi request và nhận kết quả từ Gemini API"""
58
- contents = [text_input]
59
- if image_bytes:
60
- contents.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type="image/jpeg")))
61
 
62
  response = await asyncio.to_thread(
63
  client.models.generate_content,
@@ -72,9 +72,6 @@ async def generate_image(image_bytes, text_input):
72
  if part.inline_data is not None:
73
  img = Image.open(BytesIO(part.inline_data.data))
74
  images.append(img)
75
- else:
76
- print("⚠️ Gemini API không trả về kết quả hợp lệ!")
77
-
78
  return images
79
 
80
  async def process_request(image, text, num_requests):
@@ -85,10 +82,11 @@ async def process_request(image, text, num_requests):
85
 
86
  # Hợp nhất danh sách ảnh từ các request
87
  original_images = [img for result in results for img in result]
88
- resized_images = [img.resize((2560, 1440), Image.LANCZOS) for img in original_images] # Resize trước khi upscale
89
- srcnn_images = [upscale_image(img) for img in resized_images]
90
 
91
- return original_images + srcnn_images # 4 ảnh gốc + 4 ảnh đã qua SRCNN
 
92
 
93
  def gradio_interface(image, text, num_requests):
94
  """Hàm Gradio xử lý yêu cầu và trả về ảnh"""
 
17
 
18
  client = genai.Client(api_key=api_key)
19
 
20
+ # # Định nghĩa mô hình SRCNN
21
+ # def SRCNN():
22
+ # return nn.Sequential(
23
+ # nn.Conv2d(3
24
+ # , 64, kernel_size=9, padding=4),
25
+ # nn.ReLU(inplace=True),
26
+ # nn.Conv2d(64, 32, kernel_size=5, padding=2),
27
+ # nn.ReLU(inplace=True),
28
+ # nn.Conv2d(32, 3, kernel_size=5, padding=2)
29
+ # )
30
 
31
+ # # Khởi tạo mô hình SRCNN và tải trọng số từ Hugging Face
32
+ # model = SRCNN()
33
+ # pth_path = hf_hub_download(repo_id="sunbv56/srcnn", filename="srcnn_model.pth")
34
+ # model.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
35
+ # model.eval()
36
 
37
+ # def upscale_image(image):
38
+ # """Nâng cấp độ phân giải ảnh bằng SRCNN"""
39
+ # transform = transforms.Compose([
40
+ # transforms.ToTensor(),
41
+ # transforms.Lambda(lambda x: x.unsqueeze(0)) # Thêm batch dimension
42
+ # ])
43
+ # input_tensor = transform(image)
44
+ # with torch.no_grad():
45
+ # output_tensor = model(input_tensor)
46
+ # output_image = transforms.ToPILImage()(output_tensor.squeeze(0))
47
+ # return output_image
48
 
49
  def load_image_as_bytes(image_path):
50
  """Chuyển ảnh thành dữ liệu nhị phân"""
 
54
  img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
55
  return img_bytes.getvalue() # Lấy dữ liệu nhị phân
56
 
57
+ async def generate_image(image_bytes_list, text_input):
58
  """Gửi request và nhận kết quả từ Gemini API"""
59
+ image_parts = [types.Part(inline_data=types.Blob(data=img, mime_type="image/jpeg")) for img in image_bytes_list if img]
60
+ contents = [text_input, image_parts] if image_parts else [text_input]
 
61
 
62
  response = await asyncio.to_thread(
63
  client.models.generate_content,
 
72
  if part.inline_data is not None:
73
  img = Image.open(BytesIO(part.inline_data.data))
74
  images.append(img)
 
 
 
75
  return images
76
 
77
  async def process_request(image, text, num_requests):
 
82
 
83
  # Hợp nhất danh sách ảnh từ các request
84
  original_images = [img for result in results for img in result]
85
+ resized_images = [img.resize((2560, int(img.height * (2560 / img.width)))), Image.LANCZOS) for img in original_images] # Resize trước khi upscale
86
+ # srcnn_images = [upscale_image(img) for img in resized_images]
87
 
88
+ return resized_images # 4 ảnh gốc
89
+ # return resized_images + srcnn_images # 4 ảnh gốc + 4 ảnh đã qua SRCNN
90
 
91
  def gradio_interface(image, text, num_requests):
92
  """Hàm Gradio xử lý yêu cầu và trả về ảnh"""