KennethTM commited on
Commit
d081540
·
verified ·
1 Parent(s): 45c924a

Upload 7 files

Browse files
Files changed (7) hide show
  1. Dockerfile +13 -0
  2. id2spec.bin +3 -0
  3. image_encoder.bin +3 -0
  4. index.html +281 -0
  5. main.py +168 -0
  6. requirements.txt +9 -0
  7. species_features.bin +3 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
id2spec.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:396276ababa85a106aa27022c4d311e54722bd8b0af10a9189365ee940547a74
3
+ size 98956
image_encoder.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c4888f31c1e368352cb327442975096865a2e627f2e76393d2e386f1c850599
3
+ size 498149900
index.html ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <title>Matchmaking for habitater og arter</title>
5
+ <link rel="stylesheet" href="https://unpkg.com/[email protected]/dist/leaflet.css" />
6
+ <script src="https://unpkg.com/[email protected]/dist/leaflet.js"></script>
7
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js"></script>
8
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css"/>
9
+ <style>
10
+ html, body, #map {
11
+ height: 100%;
12
+ width: 100%;
13
+ margin: 0;
14
+ padding: 0;
15
+ }
16
+ #downloadButton {
17
+ position: absolute;
18
+ top: 10px;
19
+ right: 10px;
20
+ z-index: 401;
21
+ padding: 10px;
22
+ background-color: white;
23
+ border: 1px solid black;
24
+ cursor: pointer;
25
+ display: none;
26
+ }
27
+
28
+ /* Loading spinner styles */
29
+ .spinner {
30
+ position: absolute;
31
+ width: 40px;
32
+ height: 40px;
33
+ margin: 0;
34
+ background-color: rgba(255, 255, 255, 0.8);
35
+ border-radius: 50%;
36
+ border: 3px solid transparent;
37
+ border-top-color: #3498db;
38
+ border-bottom-color: #3498db;
39
+ animation: spin 2s linear infinite;
40
+ z-index: 1000;
41
+ }
42
+
43
+ /* Add this if not already present */
44
+ @keyframes spin {
45
+ 0% { transform: rotate(0deg); }
46
+ 100% { transform: rotate(360deg); }
47
+ }
48
+
49
+ .spinner-container {
50
+ background: none !important;
51
+ }
52
+
53
+ /* Make sure there's no Leaflet default icon background */
54
+ .leaflet-div-icon {
55
+ background: transparent;
56
+ border: none;
57
+ }
58
+ </style>
59
+ </head>
60
+ <body>
61
+ <div id="map"></div>
62
+ <button id="downloadButton">Download GeoJSON</button>
63
+
64
+ <script>
65
+ var map = L.map('map').setView([56.2, 10.3], 7);
66
+
67
+ L.tileLayer('https://services.datafordeler.dk/GeoDanmarkOrto/orto_foraar_webm/1.0.0/WMTS/orto_foraar_webm/default/DFD_GoogleMapsCompatible/{z}/{y}/{x}.jpg?username=BJSIGPGRVW&password=Panseryrtat*56klinge', {
68
+ attribution: 'CC BY 4.0, GeoDanmark, Forårsbilleder Ortofoto, dataforsyningen.dk',
69
+ maxZoom: 19
70
+ }).addTo(map);
71
+
72
+ var drawnItems = new L.FeatureGroup();
73
+ map.addLayer(drawnItems);
74
+
75
+ var drawControl = new L.Control.Draw({
76
+ draw: {
77
+ polygon: true,
78
+ polyline: false,
79
+ circle: false,
80
+ rectangle: false,
81
+ marker: false,
82
+ circlemarker: false
83
+ },
84
+ edit: {
85
+ featureGroup: drawnItems
86
+ }
87
+ });
88
+ map.addControl(drawControl);
89
+
90
+ map.on('draw:created', function (e) {
91
+ var layer = e.layer;
92
+ drawnItems.addLayer(layer);
93
+ predictAndShow(layer);
94
+ });
95
+
96
+ map.on('draw:edited', function(e){
97
+ var layers = e.layers;
98
+ layers.eachLayer(function(layer) {
99
+ predictAndShow(layer);
100
+ });
101
+ });
102
+
103
+ map.on('draw:deleted', function(e){
104
+ updateDownloadButton();
105
+ });
106
+
107
+
108
+
109
+
110
+ function predictAndShow(layer) {
111
+ var geojson = layer.toGeoJSON();
112
+
113
+ // Get the center of the polygon for placing the spinner
114
+ var bounds = layer.getBounds();
115
+ var center = bounds.getCenter();
116
+
117
+ // Create a more visible spinner with custom HTML
118
+ var spinnerHtml = '<div class="spinner" style="width: 25px; height: 25px; ' +
119
+ 'border: 5px solid #f3f3f3; border-top: 5px solid #3498db; ' +
120
+ 'border-radius: 50%; animation: spin 2s linear infinite;"></div>';
121
+
122
+ var spinner = L.divIcon({
123
+ html: spinnerHtml,
124
+ className: 'spinner-container',
125
+ iconSize: [50, 50],
126
+ iconAnchor: [25, 25] // Center the spinner on the point
127
+ });
128
+
129
+ // Add the spinner to the map
130
+ var loadingMarker = L.marker(center, {
131
+ icon: spinner,
132
+ interactive: false,
133
+ zIndexOffset: 1000 // Ensure spinner appears above other elements
134
+ }).addTo(map);
135
+
136
+ // Change the polygon style to indicate loading
137
+ var originalStyle = {
138
+ color: layer.options.color || '#3388ff',
139
+ fillOpacity: layer.options.fillOpacity || 0.2
140
+ };
141
+
142
+ layer.setStyle({
143
+ fillOpacity: 0.1,
144
+ color: '#aaa'
145
+ });
146
+
147
+ fetch('/predict', {
148
+ method: 'POST',
149
+ headers: {
150
+ 'Content-Type': 'application/json'
151
+ },
152
+ body: JSON.stringify({ geojson: geojson })
153
+ })
154
+ .then(response => response.json())
155
+ .then(data => {
156
+ // Remove the spinner and restore original style
157
+ map.removeLayer(loadingMarker);
158
+ layer.setStyle(originalStyle);
159
+
160
+ var predictions = data.predictions;
161
+ var popupContent = "<b>Arter:</b><br>";
162
+
163
+ // Display all predictions
164
+ Object.entries(predictions).forEach(([species, score]) => {
165
+ popupContent += species + ": " + score.toFixed(2) + "<br>";
166
+ });
167
+
168
+ // Store the popup content in the layer for later use
169
+ layer.popupContent = popupContent;
170
+
171
+ // Only add click handler, no mouseover/hover effects
172
+ layer.on('click', function(e) {
173
+ if (!layer._popup || !map.hasLayer(layer._popup)) {
174
+ var popup = L.popup({
175
+ closeButton: true,
176
+ autoClose: false,
177
+ closeOnEscapeKey: false,
178
+ closeOnClick: false
179
+ })
180
+ .setLatLng(e.latlng)
181
+ .setContent(layer.popupContent);
182
+
183
+ layer.bindPopup(popup).openPopup();
184
+ } else {
185
+ layer.closePopup();
186
+ }
187
+ });
188
+
189
+ // Ensure the feature object exists before assigning to it
190
+ if (!layer.feature) {
191
+ layer.feature = {};
192
+ }
193
+ if (!layer.feature.properties) {
194
+ layer.feature.properties = {};
195
+ }
196
+
197
+ // Store both the raw predictions and formatted text
198
+ layer.feature.properties.arter = predictions;
199
+ updateDownloadButton();
200
+ })
201
+ .catch(error => {
202
+ // Remove the spinner and restore original style on error
203
+ map.removeLayer(loadingMarker);
204
+ layer.setStyle(originalStyle);
205
+
206
+ console.error('Error:', error);
207
+ alert('Prediction failed.');
208
+ });
209
+ }
210
+
211
+
212
+ document.getElementById('downloadButton').addEventListener('click', function() {
213
+ var features = [];
214
+
215
+ // Collect all drawn layers with their prediction data
216
+ drawnItems.eachLayer(function(layer){
217
+ // Get the GeoJSON representation of the layer
218
+ var featureGeoJSON = layer.toGeoJSON();
219
+
220
+ // Ensure type is explicitly set to "Feature"
221
+ featureGeoJSON.type = "Feature";
222
+
223
+ // Make sure we have properties object
224
+ if (!featureGeoJSON.properties) {
225
+ featureGeoJSON.properties = {};
226
+ }
227
+
228
+ // Ensure prediction data is included in properties
229
+ if (layer.feature && layer.feature.properties && layer.feature.properties.arter) {
230
+ featureGeoJSON.properties.arter = layer.feature.properties.arter;
231
+ }
232
+
233
+ features.push(featureGeoJSON);
234
+ });
235
+
236
+ // Create a proper GeoJSON FeatureCollection
237
+ var featureCollection = {
238
+ "type": "FeatureCollection",
239
+ "features": features
240
+ };
241
+
242
+ // Convert to a JSON string with pretty formatting
243
+ var geojsonString = JSON.stringify(featureCollection, null, 2);
244
+
245
+ // Create a Blob from the GeoJSON
246
+ var blob = new Blob([geojsonString], {type: 'application/geo+json'});
247
+
248
+ // Create download link
249
+ var url = window.URL.createObjectURL(blob);
250
+ var a = document.createElement('a');
251
+ a.style.display = 'none';
252
+ a.href = url;
253
+ a.download = 'polygoner.geojson';
254
+ document.body.appendChild(a);
255
+ a.click();
256
+
257
+ // Clean up
258
+ setTimeout(function() {
259
+ document.body.removeChild(a);
260
+ window.URL.revokeObjectURL(url);
261
+ }, 100);
262
+ });
263
+
264
+ function updateDownloadButton(){
265
+ var geojsonData = [];
266
+ drawnItems.eachLayer(function(layer){
267
+ geojsonData.push(layer.toGeoJSON());
268
+ });
269
+
270
+ if(geojsonData.length > 0){
271
+ document.getElementById('downloadButton').style.display = "block";
272
+ } else {
273
+ document.getElementById('downloadButton').style.display = "none";
274
+ }
275
+ }
276
+
277
+ updateDownloadButton();
278
+
279
+ </script>
280
+ </body>
281
+ </html>
main.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import Dict, Any
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw
7
+ import json
8
+ from dotenv import load_dotenv
9
+ import os
10
+ import requests
11
+ from io import BytesIO
12
+ from pyproj import Transformer
13
+ import onnxruntime as ort
14
+ from cryptography.fernet import Fernet
15
+ from fastapi.responses import HTMLResponse
16
+
17
+ load_dotenv()
18
+
19
+ app = FastAPI()
20
+
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"], # Allows all origins
24
+ allow_credentials=True,
25
+ allow_methods=["*"], # Allows all methods
26
+ allow_headers=["*"], # Allows all headers
27
+ )
28
+
29
+ # Model load
30
+ key = os.getenv("MODEL_KEY")
31
+ cipher = Fernet(key)
32
+
33
+ with open("species_features.bin", "rb") as f:
34
+ bin_data = f.read()
35
+ data = cipher.decrypt(bin_data)
36
+ species_features = np.load(BytesIO(data))
37
+
38
+ with open("id2spec.bin", "rb") as f:
39
+ bin_data = f.read()
40
+ data = cipher.decrypt(bin_data)
41
+ id2spec = json.loads(data)
42
+
43
+ with open("image_encoder.bin", "rb") as f:
44
+ bin_data = f.read()
45
+ data = cipher.decrypt(bin_data)
46
+ image_encoder = ort.InferenceSession(data)
47
+
48
+ transformer = Transformer.from_crs("EPSG:4326", "EPSG:25832", always_xy=True)
49
+
50
+ IMAGE_SIZE = 384
51
+
52
+ def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 00.224, 0.225)):
53
+ image = (image / 255.0).astype("float32")
54
+
55
+ image[:, :, 0] = (image[:, :, 0] - mean[0]) / std[0]
56
+ image[:, :, 1] = (image[:, :, 1] - mean[1]) / std[1]
57
+ image[:, :, 2] = (image[:, :, 2] - mean[2]) / std[2]
58
+
59
+ return image
60
+
61
+ def pad_if_needed(image, target_size):
62
+ height, width, _ = image.shape
63
+
64
+ y0 = abs((height - target_size) // 2)
65
+ x0 = abs((width - target_size) // 2)
66
+
67
+ background = np.zeros((target_size, target_size, 3), dtype="uint8")
68
+ background[y0:(y0 + height), x0:(x0 + width), :] = image
69
+
70
+ return background
71
+
72
+ def predict(image, image_size, top_k = 20):
73
+ image = image.convert("RGB")
74
+ image = np.array(image)
75
+ image = pad_if_needed(image, image_size)
76
+ image = normalize_image(image)
77
+ image = np.transpose(image, (2, 0, 1))
78
+ image = image[np.newaxis]
79
+ image_features = image_encoder.run(None, {"input.1": image})[0]
80
+
81
+ similarity = np.dot(image_features, species_features.T)
82
+
83
+ sorted_similarity = np.argsort(similarity[0])[::-1][:top_k]
84
+
85
+ species_scores = {id2spec[str(idx)]: similarity[0, idx] for idx in sorted_similarity}
86
+ return species_scores
87
+
88
+ def get_image(coords, max_dim):
89
+
90
+ coords_utm = [transformer.transform(lon, lat) for lon, lat in coords]
91
+
92
+ xs, ys = zip(*coords_utm)
93
+
94
+ xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
95
+
96
+ roi_width = xmax - xmin
97
+ roi_height = ymax - ymin
98
+ aspect_ratio = roi_width / roi_height
99
+
100
+ if aspect_ratio > 1:
101
+ width = max_dim
102
+ height = int(max_dim / aspect_ratio)
103
+ else:
104
+ width = int(max_dim * aspect_ratio)
105
+ height = max_dim
106
+
107
+ wms_params = {
108
+ 'username': os.getenv('WMSUSER'),
109
+ 'password': os.getenv('WMSPW'),
110
+ 'SERVICE': 'WMS',
111
+ 'VERSION': '1.3.0',
112
+ 'REQUEST': 'GetMap',
113
+ 'BBOX': f"{xmin},{ymin},{xmax},{ymax}",
114
+ 'CRS': 'EPSG:25832',
115
+ 'WIDTH': width,
116
+ 'HEIGHT': height,
117
+ 'LAYERS': 'orto_foraar',
118
+ 'STYLES': '',
119
+ 'FORMAT': 'image/png',
120
+ 'DPI': 96,
121
+ 'MAP_RESOLUTION': 96,
122
+ 'FORMAT_OPTIONS': 'dpi:96'
123
+ }
124
+
125
+ base_url = "https://services.datafordeler.dk/GeoDanmarkOrto/orto_foraar/1.0.0/WMS"
126
+ response = requests.get(base_url, params=wms_params)
127
+ if response.status_code != 200:
128
+ raise HTTPException(status_code=500, detail=f"Error fetching image: {response.status_code}")
129
+
130
+ img = Image.open(BytesIO(response.content))
131
+
132
+ mask = Image.new('L', (width, height), 0)
133
+
134
+ x_norm = [(x - xmin) / roi_width for x in xs]
135
+ y_norm = [(y - ymin) / roi_height for y in ys]
136
+ x_img = [int(x * width) for x in x_norm]
137
+ y_img = [int((1 - y) * height) for y in y_norm]
138
+
139
+ ImageDraw.Draw(mask).polygon(list(zip(x_img, y_img)), outline=255, fill=255)
140
+
141
+ masked_img = Image.new('RGB', img.size)
142
+ masked_img.paste(img, mask=mask)
143
+
144
+ return masked_img
145
+
146
+ class GeoJSONInput(BaseModel):
147
+ geojson: Dict[str, Any]
148
+
149
+ @app.get("/", response_class=HTMLResponse)
150
+ async def get_html():
151
+ html_file = "index.html"
152
+ with open(html_file, "r") as f:
153
+ content = f.read()
154
+ return HTMLResponse(content=content)
155
+
156
+ @app.post("/predict")
157
+ async def predict_endpoint(geojson_input: GeoJSONInput):
158
+ try:
159
+ coords = geojson_input.geojson['geometry']['coordinates'][0]
160
+ image = get_image(coords, IMAGE_SIZE)
161
+ predictions_raw = predict(image, IMAGE_SIZE)
162
+
163
+ # Convert numpy.float32 values to native Python floats
164
+ predictions = {species: float(score) for species, score in predictions_raw.items()}
165
+
166
+ return {"predictions": predictions}
167
+ except Exception as e:
168
+ raise HTTPException(status_code=500, detail=str(e))
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ numpy
4
+ onnxruntime
5
+ pydantic
6
+ gradio
7
+ cryptography
8
+ requests
9
+ pyproj
species_features.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8591833f20df53e9c0e294eda39e312ef3f7943ac3cbd0273ad5863b96843022
3
+ size 3391756