mosesb commited on
Commit
bb725a1
·
verified ·
1 Parent(s): 815e023

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +136 -133
README.md CHANGED
@@ -1,134 +1,137 @@
1
- ---
2
- license: mit
3
- library_name: timm
4
- tags:
5
- - image-classification
6
- - mobilevit
7
- - timm
8
- - drowsiness-detection
9
- - computer-vision
10
- - pytorch
11
- widget:
12
- - modelId: your-username/mobilevit-drowsiness-detection
13
- title: Drowsiness Detection with MobileViT v2
14
- url: https://huggingface.co/spaces/user-name/repo-name/resolve/main/grid_output.jpg
15
- datasets:
16
- - ismailnasri20/driver-drowsiness-dataset-ddd
17
- - yasharjebraeily/drowsy-detection-dataset
18
- metrics:
19
- - accuracy
20
- - f1
21
- - precision
22
- - recall
23
- ---
24
-
25
- # MobileViT v2 for Drowsiness Detection
26
-
27
- This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. The model is a state-of-the-art, lightweight, hybrid architecture combining convolutions with Vision Transformers, making it efficient and accurate. It classifies input images into two categories: `Drowsy` and `Non Drowsy`.
28
-
29
- This model was trained in PyTorch using the `timm` library and demonstrates high performance on an unseen test set, making it a reliable foundation for driver safety applications.
30
-
31
- ## Model Details
32
- * **Architecture:** `mobilevitv2_200`
33
- * **Fine-tuned on:** A combined dataset for driver drowsiness detection.
34
- * **Classes:** `Drowsy`, `Non Drowsy`
35
- * **Frameworks:** PyTorch, timm
36
-
37
- ## How to Get Started
38
-
39
- You can easily use this model with the `timm` and `torch` libraries. First, ensure you have the `best_model.pt` file from this repository.
40
-
41
- ```python
42
- # Install required libraries
43
- !pip install timm torch torchvision
44
-
45
- import torch
46
- import timm
47
- from PIL import Image
48
- from torchvision import transforms
49
-
50
- # --- 1. Setup Model and Preprocessing ---
51
- # Define the same transformations used for validation/testing
52
- val_test_transform = transforms.Compose([
53
- transforms.Resize((224, 224)),
54
- transforms.ToTensor(),
55
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
56
- ])
57
-
58
- # Define class names (ensure order matches training: Drowsy=0, Non Drowsy=1)
59
- class_names = ['Drowsy', 'Non Drowsy']
60
-
61
- # Load the model architecture
62
- model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)
63
-
64
- # Load the fine-tuned weights
65
- model_path = 'best_model.pt'
66
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
67
- model.eval()
68
-
69
- # --- 2. Run Inference ---
70
- image_path = 'path/to/your/image.jpg'
71
- image = Image.open(image_path).convert('RGB')
72
-
73
- # Preprocess the image
74
- input_tensor = val_test_transform(image).unsqueeze(0) # Add batch dimension
75
-
76
- # Get model prediction
77
- with torch.no_grad():
78
- output = model(input_tensor)
79
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
80
- top_prob, top_class_index = torch.topk(probabilities, 1)
81
-
82
- class_name = class_names[top_class_index.item()]
83
- confidence = top_prob.item()
84
-
85
- print(f"Prediction: {class_name} with confidence {confidence:.4f}")
86
- ```
87
-
88
- ## Training Procedure
89
-
90
- The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved:
91
- - **Data Augmentation:** A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing`.
92
- - **Transfer Learning:** The model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.
93
- - **Early Stopping:** Training was halted after 30 epochs of no improvement in validation accuracy to prevent overfitting.
94
-
95
- ### Key Hyperparameters
96
- - **Image Size:** 224x224
97
- - **Batch Size:** 64
98
- - **Optimizer:** AdamW (lr=1e-4)
99
- - **Scheduler:** ExponentialLR (gamma=0.90)
100
- - **Loss Function:** CrossEntropyLoss
101
-
102
- ![Training Results](training_plot.png)
103
-
104
- ## Evaluation
105
-
106
- The model was evaluated on a completely **unseen test set** (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.
107
-
108
- ### Key Performance Metrics
109
- | Metric | Value | Description |
110
- | :----: | :----: | :------------------------------------------------- |
111
- | **Accuracy** | 98.18% | Overall correctness on the test set. |
112
- | **APCER** | 3.57% | Rate of 'Drowsy' drivers missed (False Negatives). |
113
- | **BPCER** | 0.00% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
114
- | **ACER** | 1.78% | Average of APCER and BPCER. |
115
-
116
- *APCER (Attack Presentation Classification Error Rate, adapted here) is the most critical safety metric, as it measures the failure to detect a drowsy driver.*
117
-
118
- ![Confusion Matrix](output_confusion_matrix.png)
119
-
120
- ### Model Explainability (Grad-CAM)
121
- To ensure the model is focusing on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.
122
-
123
- ![Grad-CAM Visualization](output_grad_cam.jpg)
124
-
125
- ## Intended Use and Limitations
126
- This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.
127
-
128
- Real-world performance may vary based on:
129
- - Lighting conditions (especially at night).
130
- - Camera angles and distance.
131
- - Occlusions (e.g., sunglasses, hats, hands on face).
132
- - Individual differences not represented in the training data.
133
-
 
 
 
134
  *This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).*
 
1
+ ---
2
+ license: mit
3
+ library_name: timm
4
+ tags:
5
+ - image-classification
6
+ - mobilevit
7
+ - timm
8
+ - drowsiness-detection
9
+ - computer-vision
10
+ - pytorch
11
+ widget:
12
+ - modelId: mosesb/drowsiness-detection-mobileViT-v2
13
+ title: Drowsiness Detection with MobileViT v2
14
+ url: >-
15
+ https://huggingface.co/spaces/mosesb/drowsiness-detection-mobileViT-v2/resolve/main/output_grad_cam.jpg
16
+ datasets:
17
+ - ismailnasri20/driver-drowsiness-dataset-ddd
18
+ - yasharjebraeily/drowsy-detection-dataset
19
+ metrics:
20
+ - accuracy
21
+ - f1
22
+ - precision
23
+ - recall
24
+ base_model:
25
+ - apple/mobilevitv2-1.0-imagenet1k-256
26
+ ---
27
+
28
+ # MobileViT v2 for Drowsiness Detection
29
+
30
+ This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. The model is a state-of-the-art, lightweight, hybrid architecture combining convolutions with Vision Transformers, making it efficient and accurate. It classifies input images into two categories: `Drowsy` and `Non Drowsy`.
31
+
32
+ This model was trained in PyTorch using the `timm` library and demonstrates high performance on an unseen test set, making it a reliable foundation for driver safety applications.
33
+
34
+ ## Model Details
35
+ * **Architecture:** `mobilevitv2_200`
36
+ * **Fine-tuned on:** A combined dataset for driver drowsiness detection.
37
+ * **Classes:** `Drowsy`, `Non Drowsy`
38
+ * **Frameworks:** PyTorch, timm
39
+
40
+ ## How to Get Started
41
+
42
+ You can easily use this model with the `timm` and `torch` libraries. First, ensure you have the `best_model.pt` file from this repository.
43
+
44
+ ```python
45
+ # Install required libraries
46
+ !pip install timm torch torchvision
47
+
48
+ import torch
49
+ import timm
50
+ from PIL import Image
51
+ from torchvision import transforms
52
+
53
+ # --- 1. Setup Model and Preprocessing ---
54
+ # Define the same transformations used for validation/testing
55
+ val_test_transform = transforms.Compose([
56
+ transforms.Resize((224, 224)),
57
+ transforms.ToTensor(),
58
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
59
+ ])
60
+
61
+ # Define class names (ensure order matches training: Drowsy=0, Non Drowsy=1)
62
+ class_names = ['Drowsy', 'Non Drowsy']
63
+
64
+ # Load the model architecture
65
+ model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)
66
+
67
+ # Load the fine-tuned weights
68
+ model_path = 'best_model.pt'
69
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
70
+ model.eval()
71
+
72
+ # --- 2. Run Inference ---
73
+ image_path = 'path/to/your/image.jpg'
74
+ image = Image.open(image_path).convert('RGB')
75
+
76
+ # Preprocess the image
77
+ input_tensor = val_test_transform(image).unsqueeze(0) # Add batch dimension
78
+
79
+ # Get model prediction
80
+ with torch.no_grad():
81
+ output = model(input_tensor)
82
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
83
+ top_prob, top_class_index = torch.topk(probabilities, 1)
84
+
85
+ class_name = class_names[top_class_index.item()]
86
+ confidence = top_prob.item()
87
+
88
+ print(f"Prediction: {class_name} with confidence {confidence:.4f}")
89
+ ```
90
+
91
+ ## Training Procedure
92
+
93
+ The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved:
94
+ - **Data Augmentation:** A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing`.
95
+ - **Transfer Learning:** The model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.
96
+ - **Early Stopping:** Training was halted after 30 epochs of no improvement in validation accuracy to prevent overfitting.
97
+
98
+ ### Key Hyperparameters
99
+ - **Image Size:** 224x224
100
+ - **Batch Size:** 64
101
+ - **Optimizer:** AdamW (lr=1e-4)
102
+ - **Scheduler:** ExponentialLR (gamma=0.90)
103
+ - **Loss Function:** CrossEntropyLoss
104
+
105
+ ![Training Results](training_plot.png)
106
+
107
+ ## Evaluation
108
+
109
+ The model was evaluated on a completely **unseen test set** (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.
110
+
111
+ ### Key Performance Metrics
112
+ | Metric | Value | Description |
113
+ | :----: | :----: | :------------------------------------------------- |
114
+ | **Accuracy** | 98.18% | Overall correctness on the test set. |
115
+ | **APCER** | 3.57% | Rate of 'Drowsy' drivers missed (False Negatives). |
116
+ | **BPCER** | 0.00% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
117
+ | **ACER** | 1.78% | Average of APCER and BPCER. |
118
+
119
+ *APCER (Attack Presentation Classification Error Rate, adapted here) is the most critical safety metric, as it measures the failure to detect a drowsy driver.*
120
+
121
+ ![Confusion Matrix](output_confusion_matrix.png)
122
+
123
+ ### Model Explainability (Grad-CAM)
124
+ To ensure the model is focusing on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.
125
+
126
+ ![Grad-CAM Visualization](output_grad_cam.jpg)
127
+
128
+ ## Intended Use and Limitations
129
+ This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.
130
+
131
+ Real-world performance may vary based on:
132
+ - Lighting conditions (especially at night).
133
+ - Camera angles and distance.
134
+ - Occlusions (e.g., sunglasses, hats, hands on face).
135
+ - Individual differences not represented in the training data.
136
+
137
  *This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).*