Pranavpai0309 committed (verified)
Commit a1f0c65
1 Parent(s): ef36773

Delete ModelCode.py

Files changed (1)
  1. ModelCode.py +0 -85
ModelCode.py DELETED
@@ -1,85 +0,0 @@
- import os
- import torch
- from transformers import CLIPProcessor, CLIPModel
- from PIL import Image
- import torchvision.transforms as transforms
- from pytesseract import image_to_string
- import cv2
- from transformers import BlipProcessor, BlipForConditionalGeneration
- from collections import Counter
- from pytesseract import pytesseract
-
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
- pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
-
- blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- clip_model = clip_model.to(device)
- blip_model = blip_model.to(device)
-
- def extract_frames(video_path, frame_rate=1):
-     cap = cv2.VideoCapture(video_path)
-     fps = cap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     count = 0
-
-     while cap.isOpened():
-         ret, frame = cap.read()
-         if not ret:
-             break
-         if int(count % (fps * frame_rate)) == 0:
-             img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-             frames.append(img)
-         count += 1
-
-     cap.release()
-     return frames
-
- def classify_frame_with_clip(image):
-     texts = ["Ayurveda", "Non-Ayurveda"]
-     inputs = clip_processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
-     outputs = clip_model(**inputs)
-     logits_per_image = outputs.logits_per_image
-     probs = logits_per_image.softmax(dim=1)
-     pred = torch.argmax(probs, dim=1).item()
-     return texts[pred]
-
- def get_caption_with_blip(image):
-     inputs = blip_processor(images=image, return_tensors="pt").to(device)
-     out = blip_model.generate(**inputs)
-     caption = blip_processor.decode(out[0], skip_special_tokens=True)
-     return caption
-
- def extract_text_with_ocr(image):
-     return image_to_string(image)
-
- def classify_video(video_path):
-     frames = extract_frames(video_path, frame_rate=2)
-
-     clip_preds = []
-     blip_preds = []
-     ocr_preds = []
-
-     for frame in frames:
-         clip_result = classify_frame_with_clip(frame)
-         clip_preds.append(clip_result)
-
-         caption = get_caption_with_blip(frame)
-         blip_input = clip_processor(text=["Ayurveda", "Non-Ayurveda"], images=frame, return_tensors="pt", padding=True).to(device)
-         blip_output = clip_model(**blip_input)
-         blip_probs = blip_output.logits_per_image.softmax(dim=1)
-         blip_pred = torch.argmax(blip_probs, dim=1).item()
-         blip_preds.append(["Ayurveda", "Non-Ayurveda"][blip_pred])
-
-         text = extract_text_with_ocr(frame)
-         if any(keyword in text.lower() for keyword in ["ayurveda", "herbal", "vedic", "naturopathy"]):
-             ocr_preds.append("Ayurveda")
-         else:
-             ocr_preds.append("Non-Ayurveda")
-
-     all_preds = clip_preds + blip_preds + ocr_preds
-     final_pred = Counter(all_preds).most_common(1)[0][0]
-     return {"Type": final_pred}