Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,12 +3,12 @@ import asyncio
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
-
import torch.optim as optim
|
| 7 |
import torchvision.transforms as transforms
|
| 8 |
from google import genai
|
| 9 |
from google.genai import types
|
| 10 |
from PIL import Image
|
| 11 |
from io import BytesIO
|
|
|
|
| 12 |
|
| 13 |
# Cấu hình API Key
|
| 14 |
api_key = os.getenv("GEMINI_API_KEY")
|
|
@@ -18,23 +18,19 @@ if not api_key:
|
|
| 18 |
client = genai.Client(api_key=api_key)
|
| 19 |
|
| 20 |
# Định nghĩa mô hình SRCNN
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def forward(self, x):
|
| 30 |
-
x = self.relu(self.conv1(x))
|
| 31 |
-
x = self.relu(self.conv2(x))
|
| 32 |
-
x = self.conv3(x)
|
| 33 |
-
return x
|
| 34 |
|
| 35 |
-
# Khởi tạo mô hình SRCNN
|
| 36 |
model = SRCNN()
|
| 37 |
-
|
|
|
|
| 38 |
model.eval()
|
| 39 |
|
| 40 |
def upscale_image(image):
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
|
|
|
| 6 |
import torchvision.transforms as transforms
|
| 7 |
from google import genai
|
| 8 |
from google.genai import types
|
| 9 |
from PIL import Image
|
| 10 |
from io import BytesIO
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
|
| 13 |
# Cấu hình API Key
|
| 14 |
api_key = os.getenv("GEMINI_API_KEY")
|
|
|
|
| 18 |
client = genai.Client(api_key=api_key)
|
| 19 |
|
| 20 |
# Định nghĩa mô hình SRCNN
|
| 21 |
+
def SRCNN():
|
| 22 |
+
return nn.Sequential(
|
| 23 |
+
nn.Conv2d(3, 64, kernel_size=9, padding=4),
|
| 24 |
+
nn.ReLU(inplace=True),
|
| 25 |
+
nn.Conv2d(64, 32, kernel_size=5, padding=2),
|
| 26 |
+
nn.ReLU(inplace=True),
|
| 27 |
+
nn.Conv2d(32, 3, kernel_size=5, padding=2)
|
| 28 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
# Khởi tạo mô hình SRCNN và tải trọng số từ Hugging Face
|
| 31 |
model = SRCNN()
|
| 32 |
+
pth_path = hf_hub_download(repo_id="sunbv56/srcnn", filename="srcnn.pth")
|
| 33 |
+
model.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
|
| 34 |
model.eval()
|
| 35 |
|
| 36 |
def upscale_image(image):
|