---
license: mit
---

**ViT-LSTM Action Recognition**

**Overview**

This project implements an action recognition model using a ViT-LSTM architecture. It takes a short video as input and predicts the action performed in the video. The model extracts frame-wise ViT features and processes them with an LSTM to capture temporal dependencies.

**Model Details**

- Base Model: ViT-Base-Patch16-224
- Architecture: ViT (feature extractor) + LSTM (temporal modeling)
- Number of Classes: 5
- Dataset: custom dataset with the following action categories:
  - BaseballPitch
  - Basketball
  - BenchPress
  - Biking
  - Billiards

**Working**

1. Extract Frames – The model extracts up to 16 frames from the uploaded video (a sampling sketch is shown after this list).
2. Feature Extraction – Each frame is passed through ViT to obtain a feature vector.
3. Temporal Processing – The LSTM processes these frame features to capture motion information.
4. Prediction – The final output is classified into one of the 5 action categories.
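
The frame-extraction step is not spelled out in this README beyond the 16-frame limit, so the following is a minimal sketch of one plausible approach: uniformly sampling at most 16 evenly spaced frames with OpenCV. The helper name `sample_frames` and the uniform-sampling strategy are assumptions for illustration, not the released preprocessing code.

````python
import cv2
import numpy as np
from PIL import Image

def sample_frames(video_path, num_frames=16):
    """Hypothetical helper: uniformly sample up to `num_frames` RGB frames as PIL images."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Evenly spaced frame indices; short clips simply yield fewer frames.
    indices = np.linspace(0, max(total - 1, 0), num=min(num_frames, max(total, 1)), dtype=int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV returns BGR; convert to RGB before wrapping in a PIL image.
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return frames
````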

**Model Training Details**

- Feature Dimension: 768
- LSTM Hidden Dimension: 512
- Number of LSTM Layers: 2 (bidirectional)
- Dropout: 0.3
- Optimizer: Adam
- Loss Function: Cross-Entropy Loss
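
The model class itself is not included in this README, so here is a minimal sketch of a ViT-LSTM module consistent with the hyperparameters above (768-dimensional features, 512 hidden units, 2 bidirectional layers, dropout 0.3, 5 classes). The class name `ViTLSTM`, the last-time-step pooling, and the single linear classification head are assumptions, not the released implementation.

````python
import torch.nn as nn

class ViTLSTM(nn.Module):
    """Hypothetical ViT-LSTM head matching the listed hyperparameters."""

    def __init__(self, feature_dim=768, hidden_dim=512, num_layers=2,
                 num_classes=5, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        # A bidirectional LSTM doubles the output size; the plain linear head is an assumption.
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, num_frames, feature_dim) frame-wise ViT features
        out, _ = self.lstm(x)
        # Classify from the representation of the last time step.
        return self.classifier(out[:, -1, :])
````

Training such a module would pair it with the Adam optimizer and `nn.CrossEntropyLoss`, as listed above.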

**Example Usage (Code Snippet)**

If you want to use this model locally:

````python
import torch
import cv2
from PIL import Image
from transformers import ViTImageProcessor, ViTModel

# Load the pretrained ViT feature extractor
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
vit_model.eval()

# Load the custom ViT-LSTM model
model = torch.load("Vit-LSTM.pth", map_location="cpu")
model.eval()

# Read frames from an example video
video_path = "example.mp4"
cap = cv2.VideoCapture(video_path)
frames = []

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    # Convert BGR (OpenCV) to RGB before wrapping in a PIL image
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frames.append(Image.fromarray(frame))

cap.release()

# Keep at most 16 evenly spaced frames, matching the 16-frame limit described above
if len(frames) > 16:
    step = len(frames) / 16
    frames = [frames[int(i * step)] for i in range(16)]

# Extract frame-wise ViT features (mean-pooled over patch tokens)
inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"]
with torch.no_grad():
    features = vit_model(pixel_values=inputs).last_hidden_state.mean(dim=1)
    # Add a batch dimension -> (1, num_frames, 768) and classify
    output = model(features.unsqueeze(0))

predicted_class = torch.argmax(output, dim=1).item()

LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]
print("Predicted Action:", LABELS[predicted_class])
````
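
Note that `torch.load` above unpickles a full model object, so the model class (for example, something like the `ViTLSTM` sketch earlier) must be importable at load time, and recent PyTorch versions may additionally require `weights_only=False`.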

**Contributors**

- Saurav Dhiani – Model Development & Deployment
- ViT & LSTM – Core ML Architecture