"""FastAPI service that predicts species for the area inside a GeoJSON polygon.

The polygon's bounding box is requested from Datafordeler's WMS
(`orto_foraar` orthophoto layer). Pixels outside the polygon are blacked out,
and the masked image is embedded with an ONNX image encoder. The embedding is
then ranked against precomputed per-species feature vectors.
"""
import html
import json
import logging
import os
from io import BytesIO
from typing import Any, Dict

import numpy as np
import onnxruntime as ort
import requests
from cryptography.fernet import Fernet
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from PIL import Image, ImageDraw
from pydantic import BaseModel
from pyproj import Transformer

logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette echo back any requesting Origin; tighten if the app ever gains auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# --- Model load ---------------------------------------------------------------
# The model artifacts are Fernet-encrypted on disk; MODEL_KEY holds the key.
key = os.getenv("MODEL_KEY")
if not key:
    raise RuntimeError("MODEL_KEY environment variable is not set; cannot decrypt model files")
cipher = Fernet(key)


def _read_encrypted(path):
    """Return the decrypted bytes of the Fernet-encrypted file at *path*."""
    with open(path, "rb") as f:
        return cipher.decrypt(f.read())


# Per-species feature vectors; row i is compared against image embeddings in predict().
species_features = np.load(BytesIO(_read_encrypted("species_features.bin")))
# Maps a row index of species_features (as a string) to a species name.
id2spec = json.loads(_read_encrypted("id2spec.bin"))
# ONNX image encoder, built directly from the decrypted model bytes.
image_encoder = ort.InferenceSession(_read_encrypted("image_encoder.bin"))

# Maps species name -> GBIF species key, used to link results to GBIF.
with open("spec2key.json", "r", encoding="utf-8") as f:
    spec2key = json.load(f)

# Lon/lat (EPSG:4326) -> EPSG:25832, the CRS used for the WMS request.
transformer = Transformer.from_crs("EPSG:4326", "EPSG:25832", always_xy=True)

IMAGE_SIZE = 384
GBIF_SPECIES_URL = "https://www.gbif.org/species/"
WMS_URL = "https://services.datafordeler.dk/GeoDanmarkOrto/orto_foraar/1.0.0/WMS"
# Seconds to wait on the WMS server before giving up, so a stalled upstream
# cannot hang a worker forever.
WMS_TIMEOUT = 30


def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Scale pixels to [0, 1] and standardize each channel.

    Args:
        image: Array of shape (H, W, C) with values in 0..255, where C == len(mean).
        mean: Per-channel mean subtracted after scaling (default: ImageNet statistics).
        std: Per-channel standard deviation divided out (default: ImageNet statistics).

    Returns:
        A new float32 array with the same shape as *image*.
    """
    image = (image / 255.0).astype("float32")
    # Broadcasting over the trailing channel axis replaces the per-channel loop.
    mean = np.asarray(mean, dtype="float32")
    std = np.asarray(std, dtype="float32")
    return (image - mean) / std


def pad_if_needed(image, target_size):
    """Center an (H, W[, C]) image on a black target_size x target_size canvas.

    A dimension larger than target_size is center-cropped first, so the result
    always has exactly target_size rows and columns.

    Args:
        image: Array whose first two axes are height and width.
        target_size: Side length of the square output.

    Returns:
        A new array of shape (target_size, target_size, *image.shape[2:]).
    """
    height, width = image.shape[:2]
    if height > target_size:
        top = (height - target_size) // 2
        image = image[top:top + target_size]
    if width > target_size:
        left = (width - target_size) // 2
        image = image[:, left:left + target_size]
    height, width = image.shape[:2]

    # When the margin is odd, the extra pixel of padding goes above / left.
    y0 = (target_size - height + 1) // 2
    x0 = (target_size - width + 1) // 2
    background = np.zeros((target_size, target_size) + image.shape[2:], dtype=image.dtype)
    background[y0:y0 + height, x0:x0 + width] = image
    return background


def predict(image, image_size, top_k=20):
    """Rank species by similarity between *image* and the stored species features.

    Args:
        image: A PIL image; it is converted to RGB and padded to image_size.
        image_size: Side length of the square input fed to the encoder.
        top_k: Number of best-scoring species to return.

    Returns:
        Dict of species name -> score in percent, ordered from best to worst.
    """
    pixels = np.array(image.convert("RGB"))
    pixels = pad_if_needed(pixels, image_size)
    pixels = normalize_image(pixels)
    # HWC -> CHW, plus a leading batch axis of size 1.
    batch = np.transpose(pixels, (2, 0, 1))[np.newaxis]
    # "input.1" is the encoder's input tensor name.
    image_features = image_encoder.run(None, {"input.1": batch})[0]
    scores = np.dot(image_features, species_features.T)[0]
    top_indices = np.argsort(scores)[::-1][:top_k]
    # Maps a similarity in [-1, 1] onto 0..100 (presumably cosine similarity of
    # unit-norm features — TODO confirm against how the features were built).
    return {id2spec[str(idx)]: (float(scores[idx]) + 1) / 2 * 100 for idx in top_indices}


def format_predictions(predictions, separator="<br>"):
    """Render predictions as HTML, one "species: score%" entry per line.

    Species with a known GBIF key in spec2key are linked to their GBIF page.

    Args:
        predictions: Dict of species name -> score in percent.
        separator: Markup placed between entries.

    Returns:
        The joined HTML string.
    """
    lines = []
    for species, value in predictions.items():
        label = html.escape(species)
        gbif_key = spec2key.get(species)
        if gbif_key is not None:
            url = html.escape(f"{GBIF_SPECIES_URL}{gbif_key}")
            label = f'<a href="{url}" target="_blank" rel="noopener noreferrer">{label}</a>'
        lines.append(f"{label}: {round(value, 1)}%")
    return separator.join(lines)


def get_image(coords, max_dim):
    """Fetch the orthophoto covering a polygon and black out everything outside it.

    Args:
        coords: Polygon ring as a sequence of [lon, lat] positions; any extra
            position members (e.g. altitude) are ignored.
        max_dim: Pixel length of the longer side of the requested image.

    Returns:
        An RGB PIL image whose pixels outside the polygon are black.

    Raises:
        HTTPException: 400 for a degenerate polygon, 500 if the WMS request fails.
    """
    if len(coords) < 3:
        raise HTTPException(status_code=400, detail="Polygon must have at least 3 positions")
    coords_utm = [transformer.transform(lon, lat) for lon, lat, *_ in coords]
    xs, ys = zip(*coords_utm)
    xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
    roi_width = xmax - xmin
    roi_height = ymax - ymin
    if roi_width <= 0 or roi_height <= 0:
        raise HTTPException(status_code=400, detail="Polygon has zero width or height")

    # Keep the ROI aspect ratio; the longer side gets max_dim pixels and the
    # shorter side at least 1 so the WMS request stays valid.
    aspect_ratio = roi_width / roi_height
    if aspect_ratio > 1:
        width, height = max_dim, max(1, int(max_dim / aspect_ratio))
    else:
        width, height = max(1, int(max_dim * aspect_ratio)), max_dim

    wms_params = {
        'username': os.getenv('WMSUSER'),
        'password': os.getenv('WMSPW'),
        'SERVICE': 'WMS',
        'VERSION': '1.3.0',
        'REQUEST': 'GetMap',
        'BBOX': f"{xmin},{ymin},{xmax},{ymax}",
        'CRS': 'EPSG:25832',
        'WIDTH': width,
        'HEIGHT': height,
        'LAYERS': 'orto_foraar',
        'STYLES': '',
        'FORMAT': 'image/png',
        'DPI': 96,
        'MAP_RESOLUTION': 96,
        'FORMAT_OPTIONS': 'dpi:96',
    }
    try:
        response = requests.get(WMS_URL, params=wms_params, timeout=WMS_TIMEOUT)
    except requests.RequestException as e:
        # Never echo or log the exception text: it can contain the request URL,
        # which carries the WMS credentials as query parameters.
        logger.warning("WMS request failed: %s", type(e).__name__)
        raise HTTPException(status_code=500, detail="Error fetching image: request failed") from e
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=f"Error fetching image: {response.status_code}")

    img = Image.open(BytesIO(response.content))
    # Draw the mask at the size actually returned so paste() cannot hit a size mismatch.
    img_width, img_height = img.size
    mask = Image.new('L', img.size, 0)
    # Image rows grow downward while map y grows upward, hence the (1 - y_norm) flip.
    polygon = [
        (int((x - xmin) / roi_width * img_width), int((1 - (y - ymin) / roi_height) * img_height))
        for x, y in coords_utm
    ]
    ImageDraw.Draw(mask).polygon(polygon, outline=255, fill=255)
    masked_img = Image.new('RGB', img.size)
    masked_img.paste(img, mask=mask)
    return masked_img


class GeoJSONInput(BaseModel):
    # Raw GeoJSON object; predict_endpoint reads geometry.coordinates[0] from it.
    geojson: Dict[str, Any]


@app.get("/", response_class=HTMLResponse)
def get_html():
    """Serve the frontend page from index.html."""
    with open("index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())


# Declared with plain `def` so FastAPI runs it in its threadpool: the WMS
# request and ONNX inference block, and inside `async def` they would stall the
# event loop for every other request.
# NOTE(review): concurrent calls share `transformer` and `image_encoder`.
# onnxruntime allows concurrent InferenceSession.run calls, and pyproj documents
# Transformer as thread-safe from 3.1 — confirm the installed pyproj version.
@app.post("/predict")
def predict_endpoint(geojson_input: GeoJSONInput):
    """Predict species for the polygon in a GeoJSON Feature.

    Returns:
        {"predictions": {species: score%}, "predictions_formatted": html_str}

    Raises:
        HTTPException: 400 for malformed input, 500 for upstream or model failures.
    """
    try:
        coords = geojson_input.geojson['geometry']['coordinates'][0]
    except (KeyError, IndexError, TypeError) as e:
        raise HTTPException(
            status_code=400,
            detail="Expected a GeoJSON Feature with a Polygon geometry",
        ) from e

    try:
        image = get_image(coords, IMAGE_SIZE)
        predictions = predict(image, IMAGE_SIZE)
        predictions_formatted = format_predictions(predictions)
    except HTTPException:
        # Already carries the right status code and a safe message.
        raise
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from e
    return {"predictions": predictions, "predictions_formatted": predictions_formatted}