lexandstuff commited on
Commit
03be8ae
·
verified ·
1 Parent(s): c4b55b2

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +75 -3
  2. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1,75 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: mlx-image
4
+ tags:
5
+ - mlx
6
+ - mlx-image
7
+ - vision
8
+ - image-classification
9
+ datasets:
10
+ - imagenet-1k
11
+ ---
12
+
13
+ # efficientnet_b2
14
+
15
+ An EfficientNet B2 model architecture, pretrained on ImageNet-1K.
16
+
17
+ Disclaimer: this is a port of the Torchvision model weights to Apple MLX Framework.
18
+
19
+ See [mlx-convert-scripts](https://github.com/lextoumbourou/mlx-convert-scripts) repo for the conversion script used.
20
+
21
+ ## How to use
22
+
23
+ ```bash
24
+ pip install mlx-image
25
+ ```
26
+
27
+ Here is how to use this model for image classification:
28
+
29
+ ```python
30
+ import mlx.core as mx
31
+ from mlxim.model import create_model
32
+ from mlxim.io import read_rgb
33
+ from mlxim.transform import ImageNetTransform
34
+ from mlxim.utils.imagenet import IMAGENET2012_CLASSES
35
+
36
+ transform = ImageNetTransform(train=False, img_size=288)
37
+ x = transform(read_rgb("cat.jpg"))
38
+ x = mx.array(x)
39
+ x = mx.expand_dims(x, 0)
40
+
41
+ model = create_model("efficientnet_b2")
42
+ model.eval()
43
+
44
+ logits = model(x)
45
+ predicted_idx = mx.argmax(logits, axis=-1).item()
46
+ predicted_class = list(IMAGENET2012_CLASSES.values())[predicted_idx]
47
+
48
+ print(f"Predicted class: {predicted_class}")
49
+ ```
50
+
51
+ You can also use the embeds from layer before head:
52
+
53
+ ```python
54
+ import mlx.core as mx
55
+ from mlxim.model import create_model
56
+ from mlxim.io import read_rgb
57
+ from mlxim.transform import ImageNetTransform
58
+
59
+ transform = ImageNetTransform(train=False, img_size=288)
60
+ x = transform(read_rgb("cat.jpg"))
61
+ x = mx.array(x)
62
+ x = mx.expand_dims(x, 0)
63
+
64
+ # first option
65
+ model = create_model("efficientnet_b2", num_classes=0)
66
+ model.eval()
67
+
68
+ embeds = model(x)
69
+
70
+ # second option
71
+ model = create_model("efficientnet_b2")
72
+ model.eval()
73
+
74
+ embeds = model.get_features(x)
75
+ ```
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:784fe3d24f0e33f1557ffa86907687a0258aaad96a9aaeacec678d7d036ef884
3
+ size 36756614