Update README.md
Browse files
README.md
CHANGED
|
@@ -108,19 +108,47 @@ This repository hosts a fine-tuned `vit_large_patch16_384` classifier
|
|
| 108 |
## Usage
|
| 109 |
|
| 110 |
```python
|
| 111 |
-
import timm
|
| 112 |
-
import
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
```
|
|
|
|
| 108 |
## Usage
|
| 109 |
|
| 110 |
```python
|
| 111 |
+
import json, torch, timm
|
| 112 |
+
from huggingface_hub import hf_hub_download
|
| 113 |
+
from timm.data.transforms_factory import create_transform
|
| 114 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 115 |
+
from PIL import Image
|
| 116 |
|
| 117 |
+
REPO = "rexologue/vit_large_384_for_trees"
|
| 118 |
+
MODEL_NAME = "vit_large_patch16_384"
|
| 119 |
+
# Pick the device for inference: the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
| 120 |
+
|
| 121 |
+
# 1) labels
|
| 122 |
+
# Fetch labels.json from the Hub repository and parse it.
labels_path = hf_hub_download(REPO, filename="labels.json")
with open(labels_path, "r", encoding="utf-8") as fh:
    raw = json.load(fh)

# labels.json is handled in two shapes: a JSON object keyed by the stringified
# class index ("0", "1", ...), or a plain JSON array already in index order.
if isinstance(raw, dict):
    labels = [raw[str(idx)] for idx in range(len(raw))]
else:
    labels = list(raw)
|
| 126 |
+
|
| 127 |
+
# 2) weights
|
| 128 |
+
# Fetch the checkpoint from the Hub repository and load it onto the CPU.
ckpt_path = hf_hub_download(REPO, filename="pytorch_model.bin")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code while it is being loaded.
state = torch.load(ckpt_path, map_location="cpu", weights_only=True)

# DDP fix: drop a leading "module." from checkpoint keys. Only a true prefix is
# removed; a "module." occurring later inside a key is left untouched.
ddp_prefix = "module."
if any(key.startswith(ddp_prefix) for key in state):
    state = {
        (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
        for key, value in state.items()
    }
|
| 132 |
+
|
| 133 |
+
# 3) model
|
| 134 |
+
# Build the architecture without pretrained weights, with one output per label.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=len(labels))
# strict=True: the checkpoint keys must match the model's keys exactly.
model.load_state_dict(state, strict=True)
# Move the weights to the inference device, then switch to evaluation mode.
model.to(DEVICE)
model.eval()
|
| 137 |
+
|
| 138 |
+
# 4) preprocessing (ViT-L/16 @ 384 w/ ImageNet mean/std + bicubic)
|
| 139 |
+
# Settings passed to timm's create_transform: 3x384x384 input, bicubic
# interpolation, and ImageNet mean/std normalization.
transform_kwargs = {
    "input_size": (3, 384, 384),
    "interpolation": "bicubic",
    "mean": IMAGENET_DEFAULT_MEAN,
    "std": IMAGENET_DEFAULT_STD,
}
transform = create_transform(**transform_kwargs)
|
| 145 |
+
|
| 146 |
+
# 5) run
|
| 147 |
+
# Load the image as RGB and turn it into a batch of one on the target device.
image = Image.open("your_image.jpg").convert("RGB")
batch = transform(image).unsqueeze(0).to(DEVICE)

# Forward pass with gradient tracking disabled.
with torch.no_grad():
    logits = model(batch)

# Class probabilities for the single image, moved back to the CPU.
probs = logits.softmax(dim=1)[0].cpu()

# Report at most five predictions (fewer when there are fewer labels).
topk = probs.topk(k=min(5, len(labels)))
top_indices = topk.indices.tolist()
top_scores = topk.values.tolist()
print([(labels[idx], score) for idx, score in zip(top_indices, top_scores)])
|
| 154 |
```
|