sunbv56 commited on
Commit
31f781b
·
verified ·
1 Parent(s): d0ca6df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -3,12 +3,12 @@ import asyncio
3
  import os
4
  import torch
5
  import torch.nn as nn
6
- import torch.optim as optim
7
  import torchvision.transforms as transforms
8
  from google import genai
9
  from google.genai import types
10
  from PIL import Image
11
  from io import BytesIO
 
12
 
13
  # Cấu hình API Key
14
  api_key = os.getenv("GEMINI_API_KEY")
@@ -18,23 +18,19 @@ if not api_key:
18
  client = genai.Client(api_key=api_key)
19
 
20
  # Định nghĩa mô hình SRCNN
21
- class SRCNN(nn.Module):
22
- def __init__(self):
23
- super(SRCNN, self).__init__()
24
- self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
25
- self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
26
- self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
27
- self.relu = nn.ReLU()
28
-
29
- def forward(self, x):
30
- x = self.relu(self.conv1(x))
31
- x = self.relu(self.conv2(x))
32
- x = self.conv3(x)
33
- return x
34
 
35
- # Khởi tạo mô hình SRCNN
36
  model = SRCNN()
37
- model.load_state_dict(torch.load("srcnn.pth", map_location=torch.device('cpu')))
 
38
  model.eval()
39
 
40
  def upscale_image(image):
 
3
  import os
4
  import torch
5
  import torch.nn as nn
 
6
  import torchvision.transforms as transforms
7
  from google import genai
8
  from google.genai import types
9
  from PIL import Image
10
  from io import BytesIO
11
+ from huggingface_hub import hf_hub_download
12
 
13
  # Cấu hình API Key
14
  api_key = os.getenv("GEMINI_API_KEY")
 
18
  client = genai.Client(api_key=api_key)
19
 
20
  # Định nghĩa mô hình SRCNN
21
+ def SRCNN():
22
+ return nn.Sequential(
23
+ nn.Conv2d(3, 64, kernel_size=9, padding=4),
24
+ nn.ReLU(inplace=True),
25
+ nn.Conv2d(64, 32, kernel_size=5, padding=2),
26
+ nn.ReLU(inplace=True),
27
+ nn.Conv2d(32, 3, kernel_size=5, padding=2)
28
+ )
 
 
 
 
 
29
 
30
+ # Khởi tạo mô hình SRCNN và tải trọng số từ Hugging Face
31
  model = SRCNN()
32
+ pth_path = hf_hub_download(repo_id="sunbv56/srcnn", filename="srcnn.pth")
33
+ model.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
34
  model.eval()
35
 
36
  def upscale_image(image):