eksemyashkina's picture
Added files
f096e52
raw
history blame
5.28 kB
from typing import Union, List
from pathlib import Path
import PIL.Image
import torch
from torch import nn
import torchvision.transforms.functional as F
class Bottleneck(nn.Module):
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: nn.Module | None = None,
groups: int = 1,
dilation: int = 1,
) -> None:
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=groups, dilation=dilation, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
@property
def expansion(self):
return 4
def __init__(
self,
num_classes: int = 1000,
weights_path: str | None = None,
) -> None:
super().__init__()
if weights_path is not None and not Path(weights_path).exists():
raise FileNotFoundError(weights_path)
self.inplanes = 64
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(Bottleneck, 64, 3)
self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if weights_path:
self.load_pretrained_weights(weights_path)
def _make_layer(
self,
block: Bottleneck,
planes: int,
blocks: int,
stride: int = 1,
) -> nn.Sequential:
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def load_pretrained_weights(self, weights_path: str) -> None:
state_dict = torch.load(weights_path, map_location="cpu")
self.load_state_dict(state_dict)
@torch.inference_mode()
def predict(self, x: torch.Tensor, top_k: int | None) -> Union[List[int], List[List[int]]]:
output = self.forward(x)
probs = torch.nn.functional.softmax(output, dim=1)
if top_k is not None:
preds = torch.topk(probs, dim=1, k=top_k).indices
return preds.tolist()
else:
pred = torch.argmax(probs, dim=1)
return pred.tolist()
if __name__ == "__main__":
model = ResNet(weights_path="weights/resnet50-0676ba61.pth")
num_params = sum([p.numel() for p in model.parameters()])
print(f"params: {num_params/1e6:.2f} M")
model.eval()
image = PIL.Image.open("assets\cat.jpg").convert("RGB")
image = F.resize(image, (224, 224))
image = F.to_tensor(image)
image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = image.unsqueeze(0)
predicted_class = model.predict(image, top_k=10)
print(f"predicted class: {predicted_class}")
# https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/