from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, Any
import numpy as np
from PIL import Image, ImageDraw
import json
import os
import requests
from io import BytesIO
from pyproj import Transformer
import onnxruntime as ort
from cryptography.fernet import Fernet
from fastapi.responses import HTMLResponse
app = FastAPI()
# The bundled front-end (index.html) is served by this same app, so CORS only
# affects third-party callers. Any origin may call the API, but credentials are
# disabled: the original allow_credentials=True combined with wildcard
# origins/methods/headers is a configuration the CORS middleware documents as
# invalid for credentialed requests, and nothing here uses cookies or auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model load: the model artefacts are stored Fernet-encrypted; MODEL_KEY holds the key.
key = os.getenv("MODEL_KEY")
if not key:
    # Fail fast with an actionable message instead of an opaque error from Fernet(None).
    raise RuntimeError("MODEL_KEY environment variable is not set; cannot decrypt model files")
cipher = Fernet(key)


def _read_encrypted(path):
    """Return the decrypted bytes of the Fernet-encrypted file at *path*."""
    with open(path, "rb") as f:
        return cipher.decrypt(f.read())


# Presumably an (n_species, dim) embedding matrix; predict() computes
# image_features @ species_features.T against it — TODO confirm shape.
species_features = np.load(BytesIO(_read_encrypted("species_features.bin")))
# Maps str(row index into species_features) -> species name (see predict()).
id2spec = json.loads(_read_encrypted("id2spec.bin"))
# ONNX image encoder, loaded directly from the decrypted serialized model bytes.
image_encoder = ort.InferenceSession(_read_encrypted("image_encoder.bin"))
# Species name -> GBIF species key (plain, unencrypted JSON).
with open("spec2key.json", "r", encoding="utf-8") as f:
    spec2key = json.load(f)
# lon/lat (always_xy=True fixes the x=lon, y=lat order) -> EPSG:25832, the CRS
# requested from the WMS in get_image.
transformer = Transformer.from_crs("EPSG:4326", "EPSG:25832", always_xy=True)
# Longest side of the fetched image and side length of the square encoder input.
IMAGE_SIZE = 384
def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Scale an image to [0, 1] and standardise its first three channels.

    Args:
        image: Array indexed as ``image[row, col, channel]`` with at least three
            channels; presumably 0-255 values (predict() passes uint8 RGB).
        mean: Per-channel values subtracted after scaling; the defaults are the
            widely used ImageNet channel means.
        std: Per-channel values divided by after the mean is subtracted.

    Returns:
        A new float32 array of the same shape; *image* itself is left untouched.
    """
    scaled = (image / 255.0).astype("float32")
    for channel in range(3):
        # In-place float32 ops: same arithmetic as (x - m) / s, no temporaries kept.
        scaled[:, :, channel] -= mean[channel]
        scaled[:, :, channel] /= std[channel]
    return scaled
def pad_if_needed(image, target_size):
    """Fit an image onto a black ``target_size`` x ``target_size`` canvas.

    Dimensions smaller than *target_size* are zero-padded with the image
    centred; when the padding is odd the extra pixel goes on the top/left side.
    Dimensions larger than *target_size* are centre-cropped; previously they
    crashed with a broadcast error.

    Args:
        image: Array of shape (height, width, channels).
        target_size: Side length of the square output.

    Returns:
        A new uint8 array of shape (target_size, target_size, channels).
    """
    height, width, channels = image.shape
    if height > target_size:
        top = (height - target_size) // 2
        image = image[top:top + target_size]
        height = target_size
    if width > target_size:
        left = (width - target_size) // 2
        image = image[:, left:left + target_size]
        width = target_size
    # ceil(padding / 2): matches the original abs((size - target) // 2) offsets exactly.
    y0 = (target_size - height + 1) // 2
    x0 = (target_size - width + 1) // 2
    background = np.zeros((target_size, target_size, channels), dtype="uint8")
    background[y0:y0 + height, x0:x0 + width, :] = image
    return background
def predict(image, image_size, top_k=20):
    """Rank species by embedding similarity to an image.

    Args:
        image: PIL image; converted to RGB before encoding.
        image_size: Side length of the square canvas the image is fitted onto.
        top_k: Maximum number of species to return.

    Returns:
        Dict of species name -> ``(similarity + 1) / 2 * 100``, ordered from
        most to least similar. If the embeddings are unit-normalised (not
        checked here) the similarity is a cosine and the score lies in [0, 100].
    """
    pixels = np.array(image.convert("RGB"))
    pixels = normalize_image(pad_if_needed(pixels, image_size))
    # HWC -> CHW, then prepend a batch axis for the ONNX encoder.
    batch = np.transpose(pixels, (2, 0, 1))[np.newaxis]
    embedding = image_encoder.run(None, {"input.1": batch})[0]
    scores = np.dot(embedding, species_features.T)[0]
    best_first = np.argsort(scores)[::-1][:top_k]
    return {
        id2spec[str(idx)]: (float(scores[idx]) + 1) / 2 * 100
        for idx in best_first
    }
def format_predictions(predictions, gbif_keys=None):
    """Render predictions as one ``"<species>: <score>%"`` line per species.

    When a GBIF key is known for a species, its GBIF species page URL is
    appended in parentheses. Previously both branches produced identical text,
    so the key lookup had no effect.

    Args:
        predictions: Mapping of species name -> score; iteration order is kept.
        gbif_keys: Mapping of species name -> GBIF key. Defaults to the
            module-level ``spec2key`` table.

    Returns:
        The lines joined with newlines ("" for no predictions).
    """
    if gbif_keys is None:
        gbif_keys = spec2key
    baseurl = "https://www.gbif.org/species/"
    lines = []
    for species, value in predictions.items():
        gbif_key = gbif_keys.get(species)
        value = round(value, 1)
        if gbif_key is None:
            lines.append(f"{species}: {value}%")
        else:
            lines.append(f"{species}: {value}% ({baseurl}{gbif_key})")
    return "\n".join(lines)
def _fit_size(roi_width, roi_height, max_dim):
    """Return a (width, height) whose longer side is max_dim and whose aspect ratio follows the ROI."""
    aspect_ratio = roi_width / roi_height
    if aspect_ratio > 1:
        # max(1, ...) stops extremely thin ROIs from requesting a 0-pixel image.
        return max_dim, max(1, int(max_dim / aspect_ratio))
    return max(1, int(max_dim * aspect_ratio)), max_dim


def _polygon_mask(xs, ys, bbox, size):
    """Rasterise the projected polygon (xs, ys) onto an 'L' mask of *size* pixels spanning *bbox*."""
    xmin, ymin, xmax, ymax = bbox
    width, height = size
    points = [
        # Pixel row 0 corresponds to ymax (top edge), hence the 1 - ... flip on y.
        (int((x - xmin) / (xmax - xmin) * width), int((1 - (y - ymin) / (ymax - ymin)) * height))
        for x, y in zip(xs, ys)
    ]
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).polygon(points, outline=255, fill=255)
    return mask


def get_image(coords, max_dim, timeout=30):
    """Fetch the orthophoto covering a polygon and black out everything outside it.

    Args:
        coords: Polygon ring as ``[lon, lat, ...]`` positions; extra ordinates
            (e.g. altitude) are ignored.
        max_dim: Pixel length of the longer side of the requested image.
        timeout: Seconds to wait for the WMS server before giving up.

    Returns:
        An RGB PIL image the size of the WMS response, with pixels outside
        the polygon set to black.

    Raises:
        HTTPException: 400 for an empty polygon or a zero-width/height bounding
            box; 500 when the WMS request fails, times out, returns a non-200
            status, or does not return an image.
    """
    if not coords:
        raise HTTPException(status_code=400, detail="Polygon has no coordinates")
    coords_utm = [transformer.transform(lon, lat) for lon, lat, *_ in coords]
    xs, ys = zip(*coords_utm)
    xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
    roi_width = xmax - xmin
    roi_height = ymax - ymin
    if roi_width <= 0 or roi_height <= 0:
        # Would otherwise divide by zero when computing the aspect ratio / mask.
        raise HTTPException(status_code=400, detail="Polygon has zero width or height")
    width, height = _fit_size(roi_width, roi_height, max_dim)
    wms_params = {
        'username': os.getenv('WMSUSER'),
        'password': os.getenv('WMSPW'),
        'SERVICE': 'WMS',
        'VERSION': '1.3.0',
        'REQUEST': 'GetMap',
        'BBOX': f"{xmin},{ymin},{xmax},{ymax}",
        'CRS': 'EPSG:25832',
        'WIDTH': width,
        'HEIGHT': height,
        'LAYERS': 'orto_foraar',
        'STYLES': '',
        'FORMAT': 'image/png',
        'DPI': 96,
        'MAP_RESOLUTION': 96,
        'FORMAT_OPTIONS': 'dpi:96'
    }
    base_url = "https://services.datafordeler.dk/GeoDanmarkOrto/orto_foraar/1.0.0/WMS"
    try:
        # Without a timeout a stalled WMS server would block this call indefinitely.
        response = requests.get(base_url, params=wms_params, timeout=timeout)
    except requests.RequestException as exc:
        raise HTTPException(status_code=500, detail=f"Error fetching image: {exc}") from exc
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=f"Error fetching image: {response.status_code}")
    content_type = response.headers.get("Content-Type", "")
    if not content_type.startswith("image/"):
        # WMS service exceptions come back as XML, which servers may send with HTTP 200.
        raise HTTPException(
            status_code=500,
            detail=f"Error fetching image: unexpected content type {content_type!r}",
        )
    img = Image.open(BytesIO(response.content))
    # Build the mask at the size actually returned so paste() never sees mismatched sizes.
    mask = _polygon_mask(xs, ys, (xmin, ymin, xmax, ymax), img.size)
    masked_img = Image.new("RGB", img.size)
    masked_img.paste(img, mask=mask)
    return masked_img
class GeoJSONInput(BaseModel):
    """Request body for the /predict endpoint.

    The handler extracts the polygon ring from ``geojson``. Pydantic only
    checks that ``geojson`` is a JSON object; its GeoJSON structure is not
    validated by this model.
    """
    # Raw client-supplied GeoJSON object (presumably a Feature wrapping a Polygon — TODO confirm with the front-end).
    geojson: Dict[str, Any]
@app.get("/", response_class=HTMLResponse)
async def get_html():
    """Serve the front-end page.

    Reads ``index.html`` (relative to the process working directory) on every
    request and returns its contents as an HTML response.
    """
    html_file = "index.html"
    # Decode explicitly as UTF-8 rather than the host's locale encoding, so the
    # page is read identically on every platform (assumes the file is saved as UTF-8).
    with open(html_file, "r", encoding="utf-8") as f:
        content = f.read()
    return HTMLResponse(content=content)
def _predict_from_coords(coords):
    """Run the blocking pipeline (WMS fetch, ONNX inference, formatting) for one polygon ring."""
    image = get_image(coords, IMAGE_SIZE)
    predictions = predict(image, IMAGE_SIZE)
    predictions_formatted = format_predictions(predictions)
    return {"predictions": predictions, "predictions_formatted": predictions_formatted}


@app.post("/predict")
async def predict_endpoint(geojson_input: GeoJSONInput):
    """Predict species for the polygon in the request body.

    Accepts either a GeoJSON Feature (polygon under ``"geometry"``) or a bare
    geometry object; the first ring of ``coordinates`` is used.

    Returns:
        ``{"predictions": {species: score}, "predictions_formatted": str}``.

    Raises:
        HTTPException: 400 when no polygon ring can be read from the payload;
            HTTPExceptions raised by the pipeline are passed through unchanged;
            any other error becomes a 500 whose detail is the error message.
    """
    from fastapi.concurrency import run_in_threadpool

    geojson = geojson_input.geojson
    geometry = geojson["geometry"] if "geometry" in geojson else geojson
    try:
        coords = geometry["coordinates"][0]
    except (KeyError, IndexError, TypeError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid GeoJSON polygon: {exc!r}") from exc
    try:
        # requests.get and ONNX inference are blocking; keep them off the event loop.
        return await run_in_threadpool(_predict_from_coords, coords)
    except HTTPException:
        # Keep the status/detail chosen by get_image instead of re-wrapping as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e