rexologue committed
Commit 0228949 · verified · 1 Parent(s): dd484fb

Update README.md

Files changed (1):
  1. README.md +41 -13
README.md CHANGED
@@ -108,19 +108,47 @@ This repository hosts a fine-tuned `vit_large_patch16_384` classifier
  ## Usage

  ```python
- import timm
- import torch
-
- model = timm.create_model(
-     "vit_large_patch16_384",
-     num_classes=92,
-     pretrained=False,
- )
- state_dict = torch.hub.load_state_dict_from_url(
-     "https://huggingface.co/rexologue/vit_large_384_for_trees/resolve/main/pytorch_model.bin",
-     map_location="cpu",
-     file_name="rexologue--vit_large_384_for_trees.bin",
- )
- model.load_state_dict(state_dict)
- model.eval()
+ import json, torch, timm
+ from huggingface_hub import hf_hub_download
+ from timm.data.transforms_factory import create_transform
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from PIL import Image
+
+ REPO = "rexologue/vit_large_384_for_trees"
+ MODEL_NAME = "vit_large_patch16_384"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # 1) labels
+ labels_path = hf_hub_download(REPO, filename="labels.json")
+ with open(labels_path, "r", encoding="utf-8") as f:
+     raw = json.load(f)
+ labels = [raw[str(i)] for i in range(len(raw))] if isinstance(raw, dict) else list(raw)
+
+ # 2) weights
+ ckpt_path = hf_hub_download(REPO, filename="pytorch_model.bin")
+ state = torch.load(ckpt_path, map_location="cpu")
+ if any(k.startswith("module.") for k in state):  # DDP fix
+     state = {k.replace("module.", "", 1): v for k, v in state.items()}
+
+ # 3) model
+ model = timm.create_model(MODEL_NAME, num_classes=len(labels), pretrained=False)
+ model.load_state_dict(state, strict=True)
+ model.to(DEVICE).eval()
+
+ # 4) preprocessing (ViT-L/16 @ 384 w/ ImageNet mean/std + bicubic)
+ transform = create_transform(
+     input_size=(3, 384, 384),
+     interpolation="bicubic",
+     mean=IMAGENET_DEFAULT_MEAN,
+     std=IMAGENET_DEFAULT_STD,
+ )
+
+ # 5) run
+ img = Image.open("your_image.jpg").convert("RGB")
+ x = transform(img).unsqueeze(0).to(DEVICE)
+ with torch.no_grad():
+     logits = model(x)
+ probs = torch.softmax(logits, dim=1)[0].cpu()
+ topk = probs.topk(k=min(5, len(labels)))
+ print([(labels[i], float(probs[i])) for i in topk.indices])
  ```
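
As a usage note on the updated snippet: the sketch below runs the same pipeline over a batch of images. It is a minimal example and not part of the committed README; it assumes the variables `model`, `transform`, `labels`, and `DEVICE` from the new code above are already defined, and the `images/` folder of JPEGs is a hypothetical path.

```python
# Minimal batched-inference sketch (assumes `model`, `transform`, `labels`,
# and `DEVICE` from the README snippet above; "images/" is a hypothetical,
# non-empty folder of JPEG files).
from pathlib import Path

import torch
from PIL import Image

paths = sorted(Path("images").glob("*.jpg"))
batch = torch.stack([transform(Image.open(p).convert("RGB")) for p in paths]).to(DEVICE)

with torch.no_grad():
    probs = torch.softmax(model(batch), dim=1).cpu()

for path, p in zip(paths, probs):
    top = int(p.argmax())
    print(f"{path.name}: {labels[top]} ({float(p[top]):.3f})")
```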