codeby-hp commited on
Commit
16b5510
·
verified ·
1 Parent(s): 0d39431

Upload 8 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && apt-get install -y --no-install-recommends \
8
+ build-essential \
9
+ libssl-dev \
10
+ libffi-dev \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements first for better caching
14
+ COPY fastapi_app/requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy application files
18
+ COPY fastapi_app .
19
+
20
+ # Create non-root user for security
21
+ RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
22
+ USER appuser
23
+
24
+ EXPOSE 8000
25
+
26
+ # Health check
27
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
28
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
29
+
30
+ # Run the application
31
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
dockerignore ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Include any files or directories that you don't want to be copied to your
2
+ # container here (e.g., local build artifacts, temporary files, etc.).
3
+ #
4
+ # For more help, visit the .dockerignore file reference guide at
5
+ # https://docs.docker.com/go/build-context-dockerignore/
6
+
7
+ **/.DS_Store
8
+ **/__pycache__
9
+ **/.venv
10
+ **/.classpath
11
+ **/.dockerignore
12
+ **/.env
13
+ **/.git
14
+ **/.gitignore
15
+ **/.project
16
+ **/.settings
17
+ **/.toolstarget
18
+ **/.vs
19
+ **/.vscode
20
+ **/*.*proj.user
21
+ **/*.dbmdl
22
+ **/*.jfm
23
+ **/bin
24
+ **/charts
25
+ **/docker-compose*
26
+ **/compose.y*ml
27
+ **/Dockerfile*
28
+ **/node_modules
29
+ **/npm-debug.log
30
+ **/obj
31
+ **/secrets.dev.yaml
32
+ **/values.dev.yaml
33
+ LICENSE
34
+ README.md
35
+ **/.aws
36
+ ml-models/
fastapi_app/app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from fastapi import FastAPI, UploadFile, File, HTTPException
9
+ from fastapi.responses import HTMLResponse
10
+ from fastapi.templating import Jinja2Templates
11
+ from fastapi.requests import Request
12
+ from transformers import AutoImageProcessor, pipeline
13
+ from PIL import Image
14
+ import io
15
+
16
+ from scripts.data_model import (
17
+ PoseClassificationResponse,
18
+ PosePrediction,
19
+ )
20
+ from scripts.s3 import download_model_from_s3
21
+ from scripts.huggingface_load import download_model_from_huggingface
22
+
23
+ # Toggle between S3 and Hugging Face model loading
24
+ # Set USE_HUGGINGFACE_MODELS = False to use S3 loader (production)
25
+ # Set USE_HUGGINGFACE_MODELS = True to use Hugging Face loader (Spaces deployment)
26
+ USE_HUGGINGFACE_MODELS = True
27
+
28
+ warnings.filterwarnings("ignore")
29
+
30
+ # Configure logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Initialize FastAPI app
35
+ app = FastAPI(
36
+ title="Pose Classification API",
37
+ description="ViT-based human pose classification service",
38
+ version="0.0.0",
39
+ )
40
+
41
+ # Setup templates
42
+ template_dir = Path(__file__).parent / "templates"
43
+ if template_dir.exists():
44
+ templates = Jinja2Templates(directory=str(template_dir))
45
+
46
+ # Device selection
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ logger.info(f"Using device: {device}")
49
+
50
+ # Model initialization
51
+ MODEL_NAME = "vit-human-pose-classification"
52
+ LOCAL_MODEL_PATH = f"ml-models/{MODEL_NAME}"
53
+ FORCE_DOWNLOAD = False
54
+
55
+ # Global model variables
56
+ pose_model = None
57
+ image_processor = None
58
+
59
+
60
+ def initialize_model():
61
+ """Initialize the pose classification model."""
62
+ global pose_model, image_processor
63
+
64
+ try:
65
+ logger.info("Initializing pose classification model...")
66
+
67
+ # Download model if not present
68
+ if not os.path.isdir(LOCAL_MODEL_PATH) or FORCE_DOWNLOAD:
69
+ if USE_HUGGINGFACE_MODELS:
70
+ logger.info(f"Downloading model from Hugging Face to {LOCAL_MODEL_PATH}")
71
+ success = download_model_from_huggingface(LOCAL_MODEL_PATH)
72
+ else:
73
+ logger.info(f"Downloading model from S3 to {LOCAL_MODEL_PATH}")
74
+ success = download_model_from_s3(LOCAL_MODEL_PATH, f"{MODEL_NAME}/")
75
+
76
+ if not success:
77
+ logger.error("Failed to download model")
78
+ return False
79
+
80
+ # Load image processor
81
+ image_processor = AutoImageProcessor.from_pretrained(
82
+ LOCAL_MODEL_PATH,
83
+ use_fast=True,
84
+ local_files_only=True,
85
+ )
86
+
87
+ # Load model pipeline
88
+ pose_model = pipeline(
89
+ "image-classification",
90
+ model=LOCAL_MODEL_PATH,
91
+ device=device,
92
+ image_processor=image_processor,
93
+ )
94
+
95
+ logger.info("Model initialized successfully")
96
+ return True
97
+
98
+ except Exception as e:
99
+ logger.error(f"Error initializing model: {e}")
100
+ return False
101
+
102
+
103
+ @app.on_event("startup")
104
+ async def startup_event():
105
+ """Initialize model on startup."""
106
+ if not initialize_model():
107
+ logger.warning("Model initialization failed, app will not be fully functional")
108
+
109
+
110
+ @app.get("/", response_class=HTMLResponse)
111
+ async def read_root(request: Request):
112
+ """Serve the main UI page."""
113
+ if template_dir.exists():
114
+ return templates.TemplateResponse("index.html", {"request": request})
115
+ return """
116
+ <!DOCTYPE html>
117
+ <html>
118
+ <head><title>Pose Classification</title></head>
119
+ <body><p>Template not found</p></body>
120
+ </html>
121
+ """
122
+
123
+
124
+ @app.get("/health")
125
+ async def health_check():
126
+ """Health check endpoint."""
127
+ if pose_model is not None:
128
+ return {"status": "healthy", "model_loaded": True}
129
+ return {"status": "unhealthy", "model_loaded": False}
130
+
131
+
132
+ @app.post("/api/v1/classify")
133
+ async def classify_pose(file: UploadFile = File(...)) -> PoseClassificationResponse:
134
+ """Classify pose from uploaded image.
135
+
136
+ Args:
137
+ file: Image file to classify
138
+
139
+ Returns:
140
+ PoseClassificationResponse with prediction results
141
+ """
142
+ if pose_model is None:
143
+ raise HTTPException(
144
+ status_code=503,
145
+ detail="Model not loaded. Please try again later.",
146
+ )
147
+
148
+ try:
149
+ # Read and validate image
150
+ content = await file.read()
151
+ image = Image.open(io.BytesIO(content))
152
+
153
+ # Run inference
154
+ start_time = time.time()
155
+ output = pose_model(image)
156
+ inference_time = int((time.time() - start_time) * 1000)
157
+
158
+ # Extract top prediction
159
+ top_prediction = output[0]
160
+
161
+ return PoseClassificationResponse(
162
+ prediction=PosePrediction(
163
+ label=top_prediction["label"],
164
+ score=round(top_prediction["score"], 4),
165
+ ),
166
+ prediction_time_ms=inference_time,
167
+ )
168
+
169
+ except Exception as e:
170
+ logger.error(f"Error during inference: {e}")
171
+ raise HTTPException(
172
+ status_code=400,
173
+ detail=f"Error processing image: {str(e)}",
174
+ )
175
+
176
+ if __name__=="__main__":
177
+ import uvicorn
178
+ uvicorn.run(app="app:app", port=8000, reload=True, host="0.0.0.0")
fastapi_app/requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.6
2
+ uvicorn[standard]==0.34.0
3
+ jinja2==3.1.5
4
+ python-multipart==0.0.18
5
+ boto3==1.34.149
6
+ python-dotenv==1.0.0
7
+ transformers==4.43.3
8
+ huggingface-hub==0.23.0
9
+ torch==2.3.1
10
+ torchvision==0.18.1
11
+ accelerate==0.33.0
12
+ Pillow==10.2.0
13
+ pydantic==2.8.2
14
+ pydantic[email]==2.8.2
fastapi_app/scripts/__init__.py ADDED
File without changes
fastapi_app/scripts/data_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models for pose classification API."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
+ class PoseClassificationRequest(BaseModel):
7
+ """Request body for pose classification endpoint."""
8
+ url: str = Field(
9
+ description="Image URL for classification"
10
+ )
11
+
12
+
13
+ class PosePrediction(BaseModel):
14
+ """Single pose prediction result."""
15
+ label: str
16
+ score: float
17
+
18
+
19
+ class PoseClassificationResponse(BaseModel):
20
+ """Response body for pose classification endpoint."""
21
+ model_name: str = "vit-human-pose-classification"
22
+ prediction: PosePrediction
23
+ prediction_time_ms: int = Field(
24
+ description="Time taken for inference in milliseconds"
25
+ )
26
+
27
+
28
+
29
+
30
+
fastapi_app/scripts/huggingface_load.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face utilities for downloading ML models."""
2
+
3
+ import os
4
+ import logging
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
+ from huggingface_hub.utils import RepositoryNotFoundError
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ HF_MODEL_ID = "codeby-hp/finetune-VIT-HumanPoseClassification"
11
+
12
+
13
+ def download_model_from_huggingface(local_path: str) -> bool:
14
+ """Download model from Hugging Face Hub.
15
+
16
+ Args:
17
+ local_path: Local directory path to save model
18
+
19
+ Returns:
20
+ True if successful, False otherwise
21
+ """
22
+ try:
23
+ logger.info(f"Downloading model from Hugging Face: {HF_MODEL_ID}")
24
+ os.makedirs(local_path, exist_ok=True)
25
+
26
+ # Download image processor
27
+ logger.info("Downloading image processor...")
28
+ image_processor = AutoImageProcessor.from_pretrained(
29
+ HF_MODEL_ID,
30
+ cache_dir=local_path,
31
+ )
32
+ image_processor.save_pretrained(local_path)
33
+
34
+ # Download model
35
+ logger.info("Downloading model weights...")
36
+ model = AutoModelForImageClassification.from_pretrained(
37
+ HF_MODEL_ID,
38
+ cache_dir=local_path,
39
+ )
40
+ model.save_pretrained(local_path)
41
+
42
+ logger.info(f"Successfully downloaded model to {local_path}")
43
+ return True
44
+
45
+ except RepositoryNotFoundError as e:
46
+ logger.error(f"Model not found on Hugging Face Hub: {e}")
47
+ return False
48
+ except Exception as e:
49
+ logger.error(f"Error downloading model from Hugging Face: {e}")
50
+ return False
fastapi_app/templates/index.html ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Pose Classification</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ </head>
9
+ <body class="bg-gradient-to-br from-slate-900 via-slate-800 to-slate-900 min-h-screen">
10
+ <div class="min-h-screen flex items-center justify-center px-4 py-12">
11
+ <div class="w-full max-w-md">
12
+ <!-- Header -->
13
+ <div class="text-center mb-8">
14
+ <h1 class="text-4xl font-bold text-white mb-2">Pose Classification</h1>
15
+ <p class="text-slate-400 text-sm">Upload an image to classify human poses</p>
16
+ </div>
17
+
18
+ <!-- Main Card -->
19
+ <div class="bg-slate-800 rounded-lg shadow-xl overflow-hidden border border-slate-700">
20
+ <!-- Upload Section -->
21
+ <div class="p-8">
22
+ <form id="uploadForm" class="space-y-6">
23
+ <!-- File Input -->
24
+ <div class="relative">
25
+ <input
26
+ type="file"
27
+ id="imageInput"
28
+ accept="image/*"
29
+ class="hidden"
30
+ required
31
+ >
32
+ <label
33
+ for="imageInput"
34
+ class="flex items-center justify-center w-full px-4 py-6 border-2 border-dashed border-slate-600 rounded-lg cursor-pointer transition hover:border-blue-400 hover:bg-slate-700/50"
35
+ >
36
+ <div class="text-center">
37
+ <svg class="w-10 h-10 mx-auto mb-2 text-slate-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
38
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4"></path>
39
+ </svg>
40
+ <p class="text-slate-300 font-medium">Click to upload image</p>
41
+ <p class="text-slate-500 text-xs mt-1">PNG, JPG, JPEG up to 10MB</p>
42
+ </div>
43
+ </label>
44
+ </div>
45
+
46
+ <!-- Image Preview -->
47
+ <div id="previewContainer" class="hidden">
48
+ <img id="imagePreview" class="w-full h-64 object-cover rounded-lg" alt="Preview">
49
+ </div>
50
+
51
+ <!-- Submit Button -->
52
+ <button
53
+ type="submit"
54
+ id="submitBtn"
55
+ class="w-full bg-blue-600 hover:bg-blue-700 text-white font-semibold py-3 rounded-lg transition disabled:opacity-50 disabled:cursor-not-allowed"
56
+ disabled
57
+ >
58
+ Classify Pose
59
+ </button>
60
+ </form>
61
+ </div>
62
+
63
+ <!-- Results Section -->
64
+ <div id="resultsContainer" class="hidden border-t border-slate-700 bg-slate-700/30 p-8">
65
+ <h2 class="text-white font-semibold mb-4 text-lg">Classification Results</h2>
66
+
67
+ <div class="space-y-4">
68
+ <!-- Prediction -->
69
+ <div class="bg-slate-800/50 rounded-lg p-4">
70
+ <p class="text-slate-400 text-sm mb-1">Detected Pose</p>
71
+ <p id="predictionLabel" class="text-white text-2xl font-bold">-</p>
72
+ </div>
73
+
74
+ <!-- Confidence -->
75
+ <div class="bg-slate-800/50 rounded-lg p-4">
76
+ <p class="text-slate-400 text-sm mb-2">Confidence</p>
77
+ <div class="flex items-center space-x-3">
78
+ <div class="flex-1 bg-slate-700 rounded-full h-2">
79
+ <div id="confidenceBar" class="bg-green-500 h-2 rounded-full transition-all" style="width: 0%"></div>
80
+ </div>
81
+ <p id="confidenceScore" class="text-white font-semibold min-w-fit">0%</p>
82
+ </div>
83
+ </div>
84
+
85
+ <!-- Inference Time -->
86
+ <div class="bg-slate-800/50 rounded-lg p-4">
87
+ <p class="text-slate-400 text-sm mb-1">Inference Time</p>
88
+ <p id="inferenceTime" class="text-white text-lg font-semibold">-</p>
89
+ </div>
90
+ </div>
91
+
92
+ <!-- Reset Button -->
93
+ <button
94
+ onclick="resetForm()"
95
+ class="w-full mt-6 bg-slate-700 hover:bg-slate-600 text-white font-semibold py-2 rounded-lg transition"
96
+ >
97
+ Classify Another Image
98
+ </button>
99
+ </div>
100
+ </div>
101
+
102
+ <!-- Status Messages -->
103
+ <div id="loadingContainer" class="hidden mt-4 text-center">
104
+ <div class="inline-block">
105
+ <div class="animate-spin h-8 w-8 border-4 border-blue-400 border-t-transparent rounded-full"></div>
106
+ </div>
107
+ <p class="text-slate-400 mt-2">Processing image...</p>
108
+ </div>
109
+
110
+ <div id="errorContainer" class="hidden mt-4 p-4 bg-red-900/30 border border-red-700 rounded-lg">
111
+ <p id="errorMessage" class="text-red-300 text-sm"></p>
112
+ </div>
113
+ </div>
114
+ </div>
115
+
116
+ <script>
117
+ const uploadForm = document.getElementById('uploadForm');
118
+ const imageInput = document.getElementById('imageInput');
119
+ const previewContainer = document.getElementById('previewContainer');
120
+ const imagePreview = document.getElementById('imagePreview');
121
+ const submitBtn = document.getElementById('submitBtn');
122
+ const loadingContainer = document.getElementById('loadingContainer');
123
+ const resultsContainer = document.getElementById('resultsContainer');
124
+ const errorContainer = document.getElementById('errorContainer');
125
+ const errorMessage = document.getElementById('errorMessage');
126
+
127
+ // Handle image selection
128
+ imageInput.addEventListener('change', function(e) {
129
+ const file = e.target.files[0];
130
+ if (file) {
131
+ // Validate file size (10MB)
132
+ if (file.size > 10 * 1024 * 1024) {
133
+ showError('Image size must be less than 10MB');
134
+ imageInput.value = '';
135
+ submitBtn.disabled = true;
136
+ return;
137
+ }
138
+
139
+ // Show preview
140
+ const reader = new FileReader();
141
+ reader.onload = function(event) {
142
+ imagePreview.src = event.target.result;
143
+ previewContainer.classList.remove('hidden');
144
+ submitBtn.disabled = false;
145
+ errorContainer.classList.add('hidden');
146
+ resultsContainer.classList.add('hidden');
147
+ };
148
+ reader.readAsDataURL(file);
149
+ }
150
+ });
151
+
152
+ // Handle form submission
153
+ uploadForm.addEventListener('submit', async function(e) {
154
+ e.preventDefault();
155
+
156
+ const file = imageInput.files[0];
157
+ if (!file) return;
158
+
159
+ // Show loading state
160
+ submitBtn.disabled = true;
161
+ loadingContainer.classList.remove('hidden');
162
+ resultsContainer.classList.add('hidden');
163
+ errorContainer.classList.add('hidden');
164
+
165
+ try {
166
+ const formData = new FormData();
167
+ formData.append('file', file);
168
+
169
+ const response = await fetch('/api/v1/classify', {
170
+ method: 'POST',
171
+ body: formData
172
+ });
173
+
174
+ if (!response.ok) {
175
+ const error = await response.json();
176
+ throw new Error(error.detail || 'Classification failed');
177
+ }
178
+
179
+ const data = await response.json();
180
+ displayResults(data);
181
+
182
+ } catch (error) {
183
+ showError(error.message || 'An error occurred during classification');
184
+ } finally {
185
+ submitBtn.disabled = false;
186
+ loadingContainer.classList.add('hidden');
187
+ }
188
+ });
189
+
190
+ function displayResults(data) {
191
+ const confidence = (data.prediction.score * 100).toFixed(1);
192
+
193
+ document.getElementById('predictionLabel').textContent = data.prediction.label;
194
+ document.getElementById('confidenceScore').textContent = confidence + '%';
195
+ document.getElementById('confidenceBar').style.width = confidence + '%';
196
+ document.getElementById('inferenceTime').textContent = data.prediction_time_ms + ' ms';
197
+
198
+ resultsContainer.classList.remove('hidden');
199
+ }
200
+
201
+ function showError(message) {
202
+ errorMessage.textContent = message;
203
+ errorContainer.classList.remove('hidden');
204
+ }
205
+
206
+ function resetForm() {
207
+ imageInput.value = '';
208
+ previewContainer.classList.add('hidden');
209
+ resultsContainer.classList.add('hidden');
210
+ errorContainer.classList.add('hidden');
211
+ submitBtn.disabled = true;
212
+ }
213
+ </script>
214
+ </body>
215
+ </html>