Pranavpai0309 committed on
Commit adcd133 · verified · 1 parent: d9761bf

Upload 2 files

Files changed (2)
  1. Dockerfile +26 -0
  2. ModelCode.py +83 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.11
+
+ # Add a user
+ RUN useradd -m -u 1000 user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Install system packages
+ COPY packages.txt /tmp/packages.txt
+ USER root
+ RUN apt-get update && \
+     xargs -a /tmp/packages.txt apt-get install -y && \
+     apt-get clean && rm -rf /var/lib/apt/lists/*
+ USER user
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install Python packages
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Copy the rest of the code
+ COPY --chown=user . /app
+
+ # Run the app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
ModelCode.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import cv2
+ from PIL import Image
+ from pytesseract import image_to_string
+ from transformers import CLIPProcessor, CLIPModel
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from collections import Counter
+
+ # CLIP for zero-shot frame classification, BLIP for image captioning
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ clip_model = clip_model.to(device)
+ blip_model = blip_model.to(device)
+
+ def extract_frames(video_path, frame_rate=1):
+     """Sample one frame roughly every `frame_rate` seconds from the video."""
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     step = max(int(fps * frame_rate), 1)  # frames to skip between samples
+     frames = []
+     count = 0
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if count % step == 0:
+             img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             frames.append(img)
+         count += 1
+
+     cap.release()
+     return frames
+
+ def classify_frame_with_clip(image):
+     """Zero-shot classify a frame as Ayurveda / Non-Ayurveda with CLIP."""
+     texts = ["Ayurveda", "Non-Ayurveda"]
+     inputs = clip_processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         outputs = clip_model(**inputs)
+     probs = outputs.logits_per_image.softmax(dim=1)
+     pred = torch.argmax(probs, dim=1).item()
+     return texts[pred]
+
+ def get_caption_with_blip(image):
+     """Generate a natural-language caption for a frame with BLIP."""
+     inputs = blip_processor(images=image, return_tensors="pt").to(device)
+     out = blip_model.generate(**inputs)
+     return blip_processor.decode(out[0], skip_special_tokens=True)
+
+ def extract_text_with_ocr(image):
+     """Extract any on-screen text from a frame with Tesseract OCR."""
+     return image_to_string(image)
+
+ def classify_video(video_path):
+     """Majority vote over CLIP, BLIP-caption and OCR predictions per frame."""
+     frames = extract_frames(video_path, frame_rate=2)
+
+     keywords = ["ayurveda", "herbal", "vedic", "naturopathy"]
+     clip_preds = []
+     blip_preds = []
+     ocr_preds = []
+
+     for frame in frames:
+         # Vote 1: CLIP zero-shot classification on the raw frame
+         clip_preds.append(classify_frame_with_clip(frame))
+
+         # Vote 2: caption the frame with BLIP and check for Ayurveda-related keywords
+         caption = get_caption_with_blip(frame).lower()
+         if any(keyword in caption for keyword in keywords):
+             blip_preds.append("Ayurveda")
+         else:
+             blip_preds.append("Non-Ayurveda")
+
+         # Vote 3: OCR the frame and check for the same keywords in on-screen text
+         text = extract_text_with_ocr(frame).lower()
+         if any(keyword in text for keyword in keywords):
+             ocr_preds.append("Ayurveda")
+         else:
+             ocr_preds.append("Non-Ayurveda")
+
+     all_preds = clip_preds + blip_preds + ocr_preds
+     final_pred = Counter(all_preds).most_common(1)[0][0]
+     return {"Type": final_pred}