Upload 9 files
Browse filesuploaded the codes and sample models
- LICENSE +21 -0
- README.md +266 -3
- balanced_data_loader-1.py +216 -0
- classify_image_and_explain.py +256 -0
- data_loader.py +173 -0
- predict.py +65 -0
- requirements.txt +0 -0
- test.py +161 -0
- train.py +176 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Purushothaman
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,3 +1,266 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Interpretable-SONAR-Image-Classifier
|
2 |
+
|
3 |
+
Explainable AI for Underwater SONAR Image Classifier
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
- Python 3.6 or higher
|
8 |
+
|
9 |
+
## Running the Scripts
|
10 |
+
|
11 |
+
This guide will help you run the `data_loader.py`, `train.py`, `test.py`, `predict.py`, and `classify_image_and_explain.py` scripts directly from the command line or within a Python script.
|
12 |
+
|
13 |
+
### Prerequisites
|
14 |
+
|
15 |
+
1. **Python Installation**: Ensure you have Python installed. You can download it from [python.org](https://www.python.org/).
|
16 |
+
|
17 |
+
2. **Required Packages**: Install the required packages using `requirements.txt`.
|
18 |
+
```sh
|
19 |
+
pip install -r requirements.txt
|
20 |
+
```
|
21 |
+
|
22 |
+
### Script Descriptions and Usage
|
23 |
+
|
24 |
+
#### 1. `data_loader.py`
|
25 |
+
|
26 |
+
This script is used to load, process, split datasets (train, val, test), and augment data.
|
27 |
+
|
28 |
+
**Command Line Usage:**
|
29 |
+
|
30 |
+
```sh
|
31 |
+
python data_loader.py --path <path_to_data> --target_folder <path_to_target_folder> --dim <dimension> --batch_size <batch_size> --num_workers <num_workers> [--augment_data]
|
32 |
+
```
|
33 |
+
|
34 |
+
**Arguments:**
|
35 |
+
|
36 |
+
- `--path`: Path to the data.
|
37 |
+
- `--target_folder`: Path to the target folder where processed data will be saved.
|
38 |
+
- `--dim`: Dimension for resizing the images.
|
39 |
+
- `--batch_size`: Batch size for data loading.
|
40 |
+
- `--num_workers`: Number of workers for data loading.
|
41 |
+
- `--augment_data` (optional): Flag to enable data augmentation.
|
42 |
+
|
43 |
+
**Example:**
|
44 |
+
|
45 |
+
```sh
|
46 |
+
python data_loader.py --path "./dataset" --target_folder "./processed_data" --dim 224 --batch_size 32 --num_workers 4 --augment_data
|
47 |
+
```
|
48 |
+
|
49 |
+
**Dataset Structure:**
|
50 |
+
|
51 |
+
```sh
|
52 |
+
├── Dataset (Raw)
|
53 |
+
├── class_name_1
|
54 |
+
│ └── *.jpg
|
55 |
+
├── class_name_2
|
56 |
+
│ └── *.jpg
|
57 |
+
├── class_name_3
|
58 |
+
│ └── *.jpg
|
59 |
+
└── class_name_4
|
60 |
+
└── *.jpg
|
61 |
+
```
|
62 |
+
|
63 |
+
#### 2. `train.py`
|
64 |
+
|
65 |
+
This script is used for training and storing the models, leveraging transfer learning.
|
66 |
+
|
67 |
+
**Command Line Usage:**
|
68 |
+
|
69 |
+
```sh
|
70 |
+
python train.py --base_models <model_names> --shape <shape> --data_path <data_path> --log_dir <log_dir> --model_dir <model_dir> --epochs <epochs> --optimizer <optimizer> --learning_rate <learning_rate> --batch_size <batch_size> [--patience <patience>]
|
71 |
+
```
|
72 |
+
|
73 |
+
**Arguments:**
|
74 |
+
|
75 |
+
- `--base_models`: Comma-separated list of base model names (e.g., 'VGG16, ResNet50').
|
76 |
+
- `--shape`: Image shape (size).
|
77 |
+
- `--data_path`: Path to the data.
|
78 |
+
- `--log_dir`: Path to the log directory.
|
79 |
+
- `--model_dir`: Path to the model directory.
|
80 |
+
- `--epochs`: Number of training epochs.
|
81 |
+
- `--optimizer`: Optimizer type ('adam' or 'sgd').
|
82 |
+
- `--learning_rate`: Learning rate for the optimizer.
|
83 |
+
- `--batch_size`: Batch size for training.
|
84 |
+
- `--patience`: Patience for early stopping (to prevent overfitting).
|
85 |
+
|
86 |
+
**Example:**
|
87 |
+
|
88 |
+
```sh
|
89 |
+
python train.py --base_models "VGG16" "DenseNet121" --shape 224 224 3 --data_path "./processed_data" --log_dir "./logs" --model_dir "./models" --epochs 100 --optimizer "adam" --learning_rate 0.0001 --batch_size 32
|
90 |
+
```
|
91 |
+
|
92 |
+
#### 3. `test.py`
|
93 |
+
|
94 |
+
This script is used for testing and storing the test logs of the above-trained models.
|
95 |
+
|
96 |
+
**Command Line Usage:**
|
97 |
+
|
98 |
+
```sh
|
99 |
+
python test.py --model_path <model_path> --test_dir <test_dir> --train_dir <train_dir> --log_dir <log_dir> [--models_dir <models_dir>] [--img_path <img_path>]
|
100 |
+
```
|
101 |
+
|
102 |
+
**Arguments:**
|
103 |
+
|
104 |
+
- `--models_dir` (optional): Path to the models (directory).
|
105 |
+
- `--model_path`: Path to the model (.h5/Keras Model).
|
106 |
+
- `--img_path`: Path to the image file.
|
107 |
+
- `--test_dir`: Path to the test dataset (directory).
|
108 |
+
- `--train_dir`: Path to the training data.
|
109 |
+
- `--log_dir`: Path to the log directory.
|
110 |
+
|
111 |
+
**Example:**
|
112 |
+
|
113 |
+
```sh
|
114 |
+
python test.py --model_path "./models/vgg16_model.keras" --test_dir "./test_data" --train_dir "./data/train" --log_dir "./logs"
|
115 |
+
```
|
116 |
+
|
117 |
+
#### 4. `predict.py`
|
118 |
+
|
119 |
+
This script is used for making predictions on new images.
|
120 |
+
|
121 |
+
**Command Line Usage:**
|
122 |
+
|
123 |
+
```sh
|
124 |
+
python predict.py --model_path <model_path> --img_path <img_path> --train_dir <train_dir>
|
125 |
+
```
|
126 |
+
|
127 |
+
**Arguments:**
|
128 |
+
|
129 |
+
- `--model_path`: Path to the model file.
|
130 |
+
- `--img_path`: Path to the image file.
|
131 |
+
- `--train_dir`: Path to the training dataset (for the label decoder, can be replaced with a CSV file with slight code modifications).
|
132 |
+
|
133 |
+
**Example:**
|
134 |
+
|
135 |
+
```sh
|
136 |
+
python predict.py --model_path "./models/vgg16_model.keras" --img_path "./images/test_image.jpg" --train_dir "./data/train"
|
137 |
+
```
|
138 |
+
|
139 |
+
#### 5. `classify_image_and_explain.py`
|
140 |
+
|
141 |
+
This script is used for making predictions on new images and generating explanations using one or more explainers (LIME, SHAP, Grad-CAM). The explanations are saved in the specified output folder, with filenames indicating the method used (e.g., `lime_explanation_1.jpg`, `shap_explanation_1.jpg`, `gradcam_explanation_1.jpg`).
|
142 |
+
|
143 |
+
**Command Line Usage:**
|
144 |
+
|
145 |
+
```sh
|
146 |
+
python classify_image_and_explain.py --image_path <image_path> --model_path <model_path> --train_directory <train_directory> --num_samples <num_samples> --num_features <num_features> --segmentation_alg <segmentation_alg> --kernel_size <kernel_size> --max_dist <max_dist> --ratio <ratio> --max_evals <max_evals> --batch_size <batch_size> --explainer_types <explainer_types> --output_folder <output_folder>
|
147 |
+
```
|
148 |
+
|
149 |
+
**Arguments:**
|
150 |
+
|
151 |
+
- `--image_path` (required): Path to the input image.
|
152 |
+
- `--model_path` (required): Path to the trained model.
|
153 |
+
- `--train_directory` (required): Directory containing training images.
|
154 |
+
- `--num_samples` (default: 300): Number of samples for LIME.
|
155 |
+
- `--num_features` (default: 100): Number of features for LIME.
|
156 |
+
- `--segmentation_alg` (default: `quickshift`): Segmentation algorithm for LIME (`quickshift`, `slic`).
|
157 |
+
- `--kernel_size` (default: 4): Kernel size for the segmentation algorithm.
|
158 |
+
- `--max_dist` (default: 200): Maximum distance for the segmentation algorithm.
|
159 |
+
- `--ratio` (default: 0.2): Ratio for the segmentation algorithm.
|
160 |
+
- `--max_evals` (default: 400): Maximum evaluations for SHAP.
|
161 |
+
- `--batch_size` (default: 50): Batch size for SHAP.
|
162 |
+
- `--explainer_types` (default: 'all'): Comma-separated list of explainers to use (`lime`, `shap`, `gradcam`). Use 'all' to include all three explainers.
|
163 |
+
- `--output_folder` (optional): Folder to save explanation images.
|
164 |
+
|
165 |
+
**Example:**
|
166 |
+
|
167 |
+
```sh
|
168 |
+
python classify_image_and_explain.py --image_path "./images/test_image.jpg" --model_path "./models/model.keras" --train_directory "./data/train" --num_samples 300 --num_features 100 --segmentation_alg "quickshift" --kernel_size 4 --max_dist 200 --ratio 0.2 --max_evals 400 --batch_size 50 --explainer_types "lime, gradcam" --output_folder "./explanations"
|
169 |
+
```
|
170 |
+
|
171 |
+
### Supported Base Models
|
172 |
+
|
173 |
+
The following base models are supported for training:
|
174 |
+
- VGG16
|
175 |
+
- VGG19
|
176 |
+
- ResNet50
|
177 |
+
- ResNet101
|
178 |
+
- InceptionV3
|
179 |
+
- DenseNet121
|
180 |
+
- DenseNet201
|
181 |
+
- MobileNetV2
|
182 |
+
- Xception
|
183 |
+
- InceptionResNetV2
|
184 |
+
- NASNetLarge
|
185 |
+
- NASNetMobile
|
186 |
+
- EfficientNetB0
|
187 |
+
- EfficientNetB7
|
188 |
+
|
189 |
+
### Running Scripts in a Python Script
|
190 |
+
|
191 |
+
You can also run these scripts programmatically using Python's `subprocess` module. Here is an example of how to do this for each script:
|
192 |
+
|
193 |
+
```python
|
194 |
+
import subprocess
|
195 |
+
|
196 |
+
# Run data_loader.py
|
197 |
+
subprocess.run([
|
198 |
+
"python", "data_loader.py",
|
199 |
+
"--path", "./data",
|
200 |
+
"--target_folder", "./processed_data",
|
201 |
+
"--dim", "224",
|
202 |
+
"--batch_size", "32",
|
203 |
+
"--num_workers", "4",
|
204 |
+
"--augment_data"
|
205 |
+
])
|
206 |
+
|
207 |
+
# Run train.py
|
208 |
+
subprocess.run([
|
209 |
+
"python", "train.py",
|
210 |
+
"--base_models", "VGG16,ResNet50",
|
211 |
+
"--shape", "224, 224, 3",
|
212 |
+
"--data_path", "./data",
|
213 |
+
"--log_dir", "./logs",
|
214 |
+
"--model_dir", "./models",
|
215 |
+
"--epochs", "100",
|
216 |
+
"--optimizer", "adam",
|
217 |
+
"--learning_rate", "0.001",
|
218 |
+
"--batch_size", "32",
|
219 |
+
"--patience", "10"
|
220 |
+
])
|
221 |
+
|
222 |
+
# Run test.py
|
223 |
+
subprocess.run([
|
224 |
+
"python", "test.py",
|
225 |
+
"--models_dir", "./models",
|
226 |
+
"--img_path", "./images/test_image.jpg",
|
229 |
+
"--train_dir", "./data/train",
|
230 |
+
"--log_dir", "./logs"
|
231 |
+
])
|
232 |
+
|
233 |
+
# Run classify_image_and_explain.py
|
234 |
+
subprocess.run([
|
235 |
+
"python", "classify_image_and_explain.py",
|
236 |
+
"--image_path", "./images/test_image.jpg",
|
237 |
+
"--model_path", "./models/model.h5",
|
238 |
+
"--train_directory", "./data/train",
|
239 |
+
"--num_samples", "300",
|
240 |
+
"--num_features", "100",
|
241 |
+
"--segmentation_alg", "quickshift",
|
242 |
+
"--kernel_size", "4",
|
243 |
+
"--max_dist", "200",
|
244 |
+
"--ratio", "0.2",
|
245 |
+
"--max_evals", "400",
|
246 |
+
"--batch_size", "50",
|
247 |
+
"--explainer_types", "lime,gradcam",
|
248 |
+
"--output_folder", "./explanations"
|
249 |
+
])
|
250 |
+
```
|
251 |
+
|
252 |
+
## License
|
253 |
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
254 |
+
|
255 |
+
## Citing this project: Underwater SONAR Image Classifier with XAI (LIME)
|
256 |
+
|
257 |
+
If you use our SONAR classifier or the explainer in your research, please use the following BibTeX entry.
|
258 |
+
|
259 |
+
```
|
260 |
+
@article{natarajan2024underwater,
|
261 |
+
title={Underwater SONAR Image Classification and Analysis using LIME-based Explainable Artificial Intelligence},
|
262 |
+
author={Natarajan, Purushothaman and Nambiar, Athira},
|
263 |
+
journal={arXiv preprint arXiv:2408.12837},
|
264 |
+
year={2024}
|
265 |
+
}
|
266 |
+
```
|
balanced_data_loader-1.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
5 |
+
from tqdm import tqdm
|
6 |
+
import uuid
|
7 |
+
import random
|
8 |
+
|
9 |
+
# Parses command line arguments
|
10 |
+
def parse_arguments():
    """Build the command-line interface of the data loader and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``path``, ``dim``, ``batch_size``,
        ``target_folder``, ``augment_data``, ``balance`` and ``split_type``.
    """
    # Each entry is (flag, keyword arguments for add_argument), in declaration order.
    option_table = (
        ('--path', dict(type=str, required=True, help='Path to the folder containing images')),
        ('--dim', dict(type=int, default=224, help='Required image dimension')),
        ('--batch_size', dict(type=int, default=32, help='Batch size')),
        ('--target_folder', dict(type=str, required=True,
                                 help='Folder to store the train, test, and val splits')),
        ('--augment_data', dict(action='store_true', help='Apply data augmentation')),
        ('--balance', dict(action='store_true', help='Balance the dataset')),
        ('--split_type', dict(type=str, choices=['random', 'stratified'], default='random',
                              help='Type of data split (random or stratified)')),
    )

    cli = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
21 |
+
|
22 |
+
# Process the input images
|
23 |
+
def process_image(file_path, image_size):
    """Read an image file and return it as a square float32 RGB tensor.

    Args:
        file_path: Path of the image file to read.
        image_size: Side length (pixels) of the square output.

    Returns:
        A 3-channel float32 tensor resized to ``image_size`` x ``image_size``
        and clipped to the range [0, 1].
    """
    image = tf.io.read_file(file_path)
    # expand_animations=False keeps the result 3-D even for animated GIFs
    # (only the first frame is used); a 4-D tensor cannot be PNG-encoded later.
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32, expand_animations=False)
    image = tf.image.resize(image, [image_size, image_size])
    # Resizing may slightly overshoot the input range, so clamp back to [0, 1].
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image
|
29 |
+
|
30 |
+
# Balances the images of a specific class
|
31 |
+
def balance_class_images(image_paths, labels, target_count, image_size, label, label_to_index, output_folder):
    """Bring one class to ``target_count`` images and save the result to disk.

    Classes with too many images are randomly down-sampled; classes with too
    few keep all their images and are topped up with augmented copies.

    Args:
        image_paths: Image file paths for the whole dataset.
        labels: Encoded (integer) label of each entry in ``image_paths``.
        target_count: Desired number of images for the class.
        image_size: Side length used when loading images.
        label: Class name to balance.
        label_to_index: Mapping from class name to encoded label.
        output_folder: Root folder; images are written to ``<output_folder>/<label index>``.

    Returns:
        ``(balanced_images, balanced_labels)``. ``balanced_images`` mixes file
        paths (original images) and tensors (augmented images). Both lists are
        empty when ``label`` is unknown.
    """
    print(f"Balancing class '{label}'...")
    label_idx = label_to_index.get(label)
    if label_idx is None:
        print(f"Label '{label}' not found in label_to_index.")
        return [], []

    # Keep only the paths that belong to the requested class.
    image_paths = [img for img, lbl in zip(image_paths, labels) if lbl == label_idx]
    num_images = len(image_paths)

    print(f"Class '{label}' has {num_images} images before balancing.")

    balanced_images = []
    original_count = num_images
    synthetic_count = 0

    if num_images > target_count:
        balanced_images.extend(random.sample(image_paths, target_count))
        print(f"Removed {num_images - target_count} images from class '{label}'.")
    else:
        # Covers both "too few" and "exactly enough": the originals are always
        # kept (previously an exactly balanced class came out empty).
        balanced_images.extend(image_paths)
        num_to_add = target_count - num_images

        if num_to_add > 0:
            print(f"Class '{label}' needs {num_to_add} additional images for balancing.")

            if not image_paths:
                # Nothing to augment from; random.choice would raise on an empty list.
                print(f"Class '{label}' has no source images to augment; skipping augmentation.")
                num_to_add = 0

            while num_to_add > 0:
                img_path = random.choice(image_paths)
                image = process_image(img_path, image_size)

                for _ in range(min(num_to_add, 5)):  # Use up to 5 augmentations per image
                    balanced_images.append(augment_image(image))
                    num_to_add -= 1
                    synthetic_count += 1

            print(f"Added {synthetic_count} augmented images to class '{label}'.")

    balanced_labels = [label_idx] * len(balanced_images)
    print(f"Class '{label}' has {len(balanced_images)} images after balancing.")

    class_folder = os.path.join(output_folder, str(label_idx))
    os.makedirs(class_folder, exist_ok=True)

    for img in balanced_images:
        file_path = os.path.join(class_folder, f"{uuid.uuid4()}.png")
        # Load path entries here with the caller-provided size instead of
        # relying on save_image to resolve the size on its own.
        tensor = process_image(img, image_size) if isinstance(img, str) else img
        save_image(tensor, file_path)

    print(f"Saved {len(balanced_images)} images for class '{label}' (Original: {original_count}, Synthetic: {synthetic_count}).")

    return balanced_images, balanced_labels
|
86 |
+
|
87 |
+
# Saves an image to a file
|
88 |
+
def save_image(image, file_path, image_size=224):
    """Encode an image as PNG and write it to ``file_path``.

    Args:
        image: Either a path to an image file (loaded with ``process_image``)
            or a TensorFlow image tensor.
        file_path: Destination path of the PNG file.
        image_size: Side length used when ``image`` is a path. Previously this
            was read from an undefined global, which raised NameError.

    Raises:
        ValueError: If ``image`` is neither a path nor a TensorFlow tensor.
    """
    if isinstance(image, str):
        image = process_image(image, image_size)
    if not isinstance(image, tf.Tensor):
        raise ValueError("Expected image to be a TensorFlow tensor, but got a different type.")

    # saturate=True clamps out-of-range values before the cast instead of
    # letting them overflow in the uint8 result.
    image = tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
    encoded = tf.image.encode_png(image)

    tf.io.write_file(file_path, encoded)
|
98 |
+
|
99 |
+
# Augments an image with random transformations
|
100 |
+
def augment_image(image):
    """Apply a chain of random photometric and flip augmentations to an image.

    Args:
        image: RGB image tensor; assumed to hold floats in [0, 1], which is
            what the loaders in this file produce (TODO confirm for other callers).

    Returns:
        The augmented image, clipped to [0, 1].
    """
    # Apply random augmentations using TensorFlow functions
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
    image = tf.image.random_hue(image, max_delta=0.1)
    # Brightness/contrast shifts can push values outside [0, 1]; clamp so a
    # later uint8 conversion does not overflow.
    return tf.clip_by_value(image, 0.0, 1.0)
|
109 |
+
|
110 |
+
# Creates a list of data augmentation functions
|
111 |
+
def create_datagens():
    """Return the list of augmentation callables (currently only ``augment_image``)."""
    augmenters = [augment_image]
    return augmenters
|
113 |
+
|
114 |
+
# Balances the entire dataset by balancing each class
|
115 |
+
def balance_data(images, labels, target_count, image_size, unique_labels, label_to_index, output_folder):
    """Balance every class of the dataset via ``balance_class_images``.

    Args:
        images: Image file paths for the whole dataset.
        labels: Encoded label of each entry in ``images``.
        target_count: Desired number of images per class.
        image_size: Side length used when loading images.
        unique_labels: Class names to balance.
        label_to_index: Mapping from class name to encoded label.
        output_folder: Root folder that receives the balanced images.

    Returns:
        ``(all_balanced_images, all_balanced_labels)`` concatenated over all classes.
    """
    print(f"Balancing data: Target count per class = {target_count}")

    all_balanced_images = []
    all_balanced_labels = []

    for label in tqdm(unique_labels, desc="Balancing classes"):
        # The per-class count is computed (and reported) inside
        # balance_class_images; no extra scan of the dataset is needed here.
        balanced_images, balanced_labels = balance_class_images(
            images, labels, target_count, image_size, label, label_to_index, output_folder
        )
        all_balanced_images.extend(balanced_images)
        all_balanced_labels.extend(balanced_labels)

    # Original images are kept as path strings; anything else is an augmented tensor.
    total_original_images = sum(1 for img in all_balanced_images if isinstance(img, str))
    total_synthetic_images = len(all_balanced_images) - total_original_images

    print(f"\nTotal saved images: {len(all_balanced_images)} (Original: {total_original_images}, Synthetic: {total_synthetic_images})")

    return all_balanced_images, all_balanced_labels
|
135 |
+
|
136 |
+
# Augments an image using TensorFlow functions
|
137 |
+
def tf_augment_image(file_path, label, image_size=224):
    """Load a JPEG, resize and rescale it, and return an augmented copy.

    Args:
        file_path: Path of the JPEG file.
        label: Label passed through unchanged.
        image_size: Side length of the square output. Previously this was read
            from an undefined global, which raised NameError.

    Returns:
        ``(augmented_image, label)``.
    """
    # channels=3 forces RGB so the saturation/hue augmentations also work on
    # grayscale JPEGs.
    decoded = tf.image.decode_jpeg(tf.io.read_file(file_path), channels=3)
    image = tf.image.resize(decoded, [image_size, image_size])
    image = tf.cast(image, tf.float32) / 255.0
    augmented_image = augment_image(image)
    return augmented_image, label
|
142 |
+
|
143 |
+
|
144 |
+
def map_fn(file_path, label, image_size=224):
    """Wrap ``tf_augment_image`` in ``tf.py_function`` and restore static shapes.

    Args:
        file_path: Path tensor of the image to load.
        label: Label tensor of the image.
        image_size: Side length declared for the output image. It must equal
            the size ``tf_augment_image`` resizes to. Previously this was read
            from an undefined global, which raised NameError.

    Returns:
        ``(image, label)`` with static shapes ``[image_size, image_size, 3]`` and ``[]``.
    """
    image, label = tf.py_function(tf_augment_image, [file_path, label], [tf.float32, tf.int32])
    # Shapes are set explicitly here for the tensors returned by py_function.
    image.set_shape([image_size, image_size, 3])
    label.set_shape([])
    return image, label
|
149 |
+
|
150 |
+
# Loads images, splits them into train, validation, and test sets, and saves the splits
|
151 |
+
def load_and_save_splits(path, image_size, batch_size, balance, datagens, target_folder, split_type):
    """Split a class-per-folder dataset into train/val/test and save the splits.

    80% of the images go to ``train``; the remaining 20% are split in half
    (stratified) into ``val`` and ``test``. Each image is resized and written
    as PNG to ``<target_folder>/<split>/<label index>/``, and the source paths
    of each split are listed in ``<split>_list.txt``.

    Args:
        path: Dataset root containing one sub-folder per class.
        image_size: Side length of the saved square images.
        batch_size: Accepted for interface compatibility; not used here.
        balance: Accepted for interface compatibility; not used here.
        datagens: Accepted for interface compatibility; not used here.
        target_folder: Folder that receives the split sub-folders.
        split_type: ``'stratified'`` for a stratified train split, anything
            else for a random one.

    Raises:
        ValueError: If no image files are found under ``path``.
    """
    all_images = []
    labels = []

    # Sorted listings make the stratified split (fixed random_state) reproducible.
    for class_folder in sorted(os.listdir(path)):
        class_path = os.path.join(path, class_folder)
        if not os.path.isdir(class_path):
            continue
        for img_file in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_file)
            if os.path.isfile(img_path):  # skip nested folders
                all_images.append(img_path)
                labels.append(class_folder)  # Use the folder name as the label

    print(f"Loaded {len(all_images)} images across {len(set(labels))} classes.")
    print(f"Labels found: {set(labels)}")  # Print unique labels

    if not all_images:
        raise ValueError(f"No images found under '{path}'.")

    # sorted() gives every run the same label -> index mapping; iterating a
    # bare set of strings does not guarantee a stable order between runs.
    unique_labels = sorted(set(labels))
    label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
    encoded_labels = [label_to_index[label] for label in labels]

    print(f"Label to index mapping: {label_to_index}")

    if split_type == 'stratified':
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_indices, test_indices = next(sss.split(all_images, encoded_labels))
    else:  # random split
        total_images = len(all_images)
        indices = list(range(total_images))
        random.shuffle(indices)
        cut = int(0.8 * total_images)
        train_indices = indices[:cut]
        test_indices = indices[cut:]

    train_files = [all_images[i] for i in train_indices]
    train_labels = [encoded_labels[i] for i in train_indices]
    test_files = [all_images[i] for i in test_indices]
    test_labels = [encoded_labels[i] for i in test_indices]

    # Create validation and test sets
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(test_files, test_labels))

    val_files = [test_files[i] for i in val_indices]
    val_labels = [test_labels[i] for i in val_indices]
    test_files = [test_files[i] for i in test_indices]
    test_labels = [test_labels[i] for i in test_indices]

    # Save splits
    splits = [("train", train_files, train_labels), ("val", val_files, val_labels), ("test", test_files, test_labels)]
    for split_name, file_list, labels_list in splits:
        split_folder = os.path.join(target_folder, split_name)
        os.makedirs(split_folder, exist_ok=True)
        with open(os.path.join(split_folder, f"{split_name}_list.txt"), 'w') as file_list_file:
            for img_path, label in zip(file_list, labels_list):
                label_folder = os.path.join(split_folder, str(label))
                os.makedirs(label_folder, exist_ok=True)
                file_list_file.write(f"{img_path}\n")
                # Load with the requested size here; passing the bare path
                # made save_image depend on an undefined global image_size.
                save_image(process_image(img_path, image_size), os.path.join(label_folder, f"{uuid.uuid4()}.png"))

    print(f"Saved splits: train: {len(train_files)}, val: {len(val_files)}, test: {len(test_files)}.")
|
209 |
+
|
210 |
+
# Main function to run the data loader
|
211 |
+
def main():
    """Parse the command line and run the split-and-save pipeline."""
    cli = parse_arguments()
    load_and_save_splits(
        path=cli.path,
        image_size=cli.dim,
        batch_size=cli.batch_size,
        balance=cli.balance,
        datagens=create_datagens(),
        target_folder=cli.target_folder,
        split_type=cli.split_type,
    )


if __name__ == "__main__":
    main()
|
classify_image_and_explain.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img, load_img
|
5 |
+
from lime.lime_image import LimeImageExplainer, SegmentationAlgorithm
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
import argparse
|
9 |
+
import shap
|
10 |
+
import cv2
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
# Module-level counter, starting at 0.
# NOTE(review): presumably numbers the saved explanation files — confirm against its users.
image_counter = 0
|
14 |
+
# Folder name "temp_data"; NOTE(review): presumably holds cached/preprocessed data — confirm.
temp_folder = "temp_data"
|
15 |
+
# Folder name "explanations"; NOTE(review): presumably the default destination for explanation images — confirm.
output_folder = "explanations"
|
16 |
+
|
17 |
+
# Load the model and extract relevant details
|
18 |
+
def load_model_details(model_path):
    """Load a model and report its input size and last convolutional layer.

    ``.keras`` and ``.h5`` files are loaded with ``load_model`` (uncompiled);
    any other path is wrapped in a ``TFSMLayer`` inside a ``Sequential`` model.

    Args:
        model_path: Path of the model file or directory.

    Returns:
        ``(model, last_conv_layer_name, input_shape)``, where
        ``last_conv_layer_name`` is ``None`` when the model has no top-level
        ``Conv2D`` layer and ``input_shape`` is ``model.input_shape[1:3]``.
    """
    native_banners = (
        ('.keras', "Loading .keras format model..."),
        ('.h5', "Loading .h5 format model..."),
    )
    banner = next((text for suffix, text in native_banners if model_path.endswith(suffix)), None)

    if banner is not None:
        print(banner)
        model = tf.keras.models.load_model(model_path, compile=False)
    else:
        print("Loading SavedModel using TFSMLayer...")
        serving_layer = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')
        model = tf.keras.Sequential([serving_layer])

    input_shape = model.input_shape[1:3]

    # Names of the top-level Conv2D layers in model order; the last one wins.
    conv_names = [layer.name for layer in model.layers if isinstance(layer, tf.keras.layers.Conv2D)]
    last_conv_layer_name = conv_names[-1] if conv_names else None

    print(f"Model loaded with input shape: {input_shape} and last conv layer: {last_conv_layer_name}")
    return model, last_conv_layer_name, input_shape
|
40 |
+
|
41 |
+
# Load the label encoder based on the training directory
|
42 |
+
def load_label_encoder(train_directory):
    """Build an index -> class-name mapping from the training directory.

    Only sub-directories count as classes, sorted by name. Plain files in the
    directory (for example a split listing file) are ignored so they do not
    turn into bogus labels.

    Args:
        train_directory: Directory containing one sub-folder per class.

    Returns:
        Dict mapping the integer class index to the class folder name.
    """
    labels = sorted(
        entry for entry in os.listdir(train_directory)
        if os.path.isdir(os.path.join(train_directory, entry))
    )
    label_encoder = dict(enumerate(labels))
    print(f"Label encoder created: {label_encoder}")
    return label_encoder
|
47 |
+
|
48 |
+
def load_and_preprocess_image(filename, image_size):
    """Read an image file and prepare it as model input.

    Args:
        filename: Path of the image file.
        image_size: Indexable whose first two items are the target height and width.

    Returns:
        A 3-channel tensor resized to the target size and divided by 255.
    """
    print(f"Loading and preprocessing image from: {filename}")
    target_height, target_width = image_size[0], image_size[1]

    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_image(raw_bytes, channels=3)

    # The static shape is declared explicitly when not running eagerly.
    if not tf.executing_eagerly():
        decoded.set_shape([None, None, 3])

    scaled = tf.image.resize(decoded, [target_height, target_width]) / 255.0
    scaled.set_shape([target_height, target_width, 3])

    return scaled
|
62 |
+
|
63 |
+
# Create a dataset from the training directory
|
64 |
+
def create_dataset(data_dir, labels, image_size, batch_size):
    """Build a shuffled, batched ``tf.data`` pipeline from a class-per-folder directory.

    Args:
        data_dir: Root directory containing one sub-folder per entry in ``labels``.
        labels: Class folder names; a class's index is its position in this sequence.
        image_size: Target size passed to ``load_and_preprocess_image``.
        batch_size: Number of samples per batch.

    Returns:
        A dataset yielding ``(image_batch, label_index_batch)``.
    """
    print(f"Creating dataset from directory: {data_dir}")

    # (file path, class name) pairs, class by class in the order of `labels`.
    samples = [
        (os.path.join(data_dir, class_name, entry), class_name)
        for class_name in labels
        for entry in os.listdir(os.path.join(data_dir, class_name))
    ]

    index_of = {class_name: position for position, class_name in enumerate(labels)}
    image_files = [file_path for file_path, _ in samples]
    image_labels = [index_of[class_name] for _, class_name in samples]

    def _load(file_path, class_index):
        return load_and_preprocess_image(file_path, image_size), class_index

    dataset = (
        tf.data.Dataset.from_tensor_slices((image_files, image_labels))
        .map(_load)
        .shuffle(buffer_size=len(image_files))
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    print("Dataset created and batched")
    return dataset
|
85 |
+
|
86 |
+
# Save preprocessed data (images and labels) to a file
|
87 |
+
def save_preprocessed_data(X_train, y_train, file_path):
    """Pickle the ``(X_train, y_train)`` pair to ``file_path``.

    Args:
        X_train: Preprocessed samples.
        y_train: Labels belonging to ``X_train``.
        file_path: Destination of the pickle file (overwritten if present).
    """
    print(f"Saving preprocessed data to: {file_path}")
    payload = (X_train, y_train)
    with open(file_path, 'wb') as sink:
        pickle.dump(payload, sink)
|
91 |
+
|
92 |
+
|
93 |
+
def load_preprocessed_data(file_path):
    """Load and return the object pickled at ``file_path``.

    NOTE: unpickling can execute arbitrary code, so only load files this
    program produced itself.
    """
    print(f"Loading preprocessed data from: {file_path}")
    with open(file_path, 'rb') as source:
        restored = pickle.load(source)
    return restored
|
97 |
+
|
98 |
+
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Generate a Grad-CAM heatmap for the given image and model.

    Args:
        img_array: Batched model input; the heatmap is built for its first item.
        model: Keras model to explain.
        last_conv_layer_name: Name of the layer whose activations are visualised.
        pred_index: Class index to explain. ``None`` selects the highest-scoring
            class of the first item.

    Returns:
        NumPy array with the heatmap. Negative values are zeroed and the map is
        divided by its maximum; a map whose maximum is 0 is returned as zeros.
    """
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        preds = tf.convert_to_tensor(preds)
        if pred_index is None:
            # Without this, preds[:, None] only adds an axis and every class
            # score contributes to the gradient instead of the predicted one.
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # Gradient of the class score w.r.t. the conv feature map, averaged per channel.
    grads = tape.gradient(class_channel, last_conv_layer_output)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight each channel of the first item's feature map by its pooled gradient.
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # divide_no_nan returns 0 where the maximum is 0 instead of producing NaN.
    heatmap = tf.math.divide_no_nan(tf.maximum(heatmap, 0), tf.math.reduce_max(heatmap))
    return heatmap.numpy()
|
123 |
+
|
124 |
+
def save_and_display_gradcam(array, heatmap, alpha=0.8):
    """Overlay a jet-coloured Grad-CAM heatmap on an image.

    Despite its name, this function neither writes a file nor shows a
    window; it only builds and returns the blended image.

    Args:
        array: Image array indexed as (height, width, ...).
        heatmap: Heatmap values; multiplied by 255 and cast to uint8, so
            presumably expected in [0, 1] (verify against the caller).
        alpha: Weight applied to the coloured heatmap before it is added to
            ``array``.

    Returns:
        The result of ``array_to_img`` on the blended array.
    """
    print("Saving and displaying Grad-CAM result...")
    target_height, target_width = array.shape[0], array.shape[1]

    # Quantise the heatmap so each value indexes one of 256 jet colours.
    heatmap_levels = np.uint8(255 * heatmap)
    palette = plt.cm.jet(np.arange(256))[:, :3]

    # Colourise, then stretch to the image size (resize takes width first).
    colored = array_to_img(palette[heatmap_levels])
    colored = colored.resize((target_width, target_height))

    blended = img_to_array(colored) * alpha + array
    return array_to_img(blended)
|
137 |
+
|
138 |
+
def generate_splime_mask_top_n(img_array, model, explainer, top_n=1, num_features=100, num_samples=300):
    """Generate a LIME superpixel mask for the given image and model.

    Args:
        img_array: Single image array (no batch axis).
        model: Model whose ``predict`` method is handed to the explainer.
        explainer: Object providing ``explain_instance``.
        top_n: Number of top labels the explainer should explain.
        num_features: Feature count passed to both the explanation and the
            mask extraction.
        num_samples: Number of perturbed samples the explainer evaluates.

    Returns:
        Tuple ``(mask, explanation_instance)``. ``mask`` keeps the original
        pixels where the explanation mask equals 1 and is 1.0 elsewhere.
    """
    # Quickshift segmentation with fixed settings.
    superpixel_fn = SegmentationAlgorithm('quickshift', kernel_size=4, max_dist=200, ratio=0.2)

    explanation_instance = explainer.explain_instance(
        img_array,
        model.predict,
        top_labels=top_n,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        segmentation_fn=superpixel_fn,
    )

    best_label = explanation_instance.top_labels[0]
    image_and_mask = explanation_instance.get_image_and_mask(
        best_label,
        positive_only=False,
        num_features=num_features,
        hide_rest=True,
    )
    explanation_mask = image_and_mask[1]

    # Pixels whose mask value is exactly 1 are the highlighted regions.
    highlighted = explanation_mask == 1

    # Copy the highlighted pixels onto a blank canvas shaped like the input.
    mask = np.zeros_like(img_array)
    mask[highlighted] = img_array[highlighted]

    # Everything that is not highlighted becomes 1.0.
    mask = np.where(explanation_mask[:, :, np.newaxis] == 1, mask, 1.0)

    return mask, explanation_instance
|
160 |
+
|
161 |
+
|
162 |
+
def explain_image_shap(img, model, class_names, top_prediction, max_evals=1000, batch_size=50):
    """Generate SHAP explanations for the given image and model.

    Args:
        img: Batched image array; ``img[0].shape`` sizes the masker.
        model: Model whose ``predict`` method is explained.
        class_names: Output names handed to the SHAP explainer.
        top_prediction: Accepted for interface compatibility; not used here.
        max_evals: Evaluation budget passed to the explainer call.
        batch_size: Batch size passed to the explainer call.

    Returns:
        The object returned by calling the SHAP explainer on ``img``.
    """
    # Hidden regions are filled using the "inpaint_telea" masker mode.
    image_masker = shap.maskers.Image("inpaint_telea", img[0].shape)

    def predict_batch(X):
        # Plain-function wrapper so SHAP receives a generic callable.
        return model.predict(X)

    shap_explainer = shap.Explainer(predict_batch, image_masker, output_names=class_names)

    # Restrict the explanation to the first output after a flipped argsort.
    top_output_only = shap.Explanation.argsort.flip[:1]
    return shap_explainer(img, max_evals=max_evals, batch_size=batch_size, outputs=top_output_only)
|
177 |
+
|
178 |
+
def classify_image_and_explain(image_path, model_path, train_directory, num_samples, num_features, segmentation_alg, kernel_size, max_dist, ratio, max_evals, batch_size, explainer_types, output_folder):
    """Classify one image and write the requested explanation images.

    Args:
        image_path: Path of the image to classify.
        model_path: Path handed to ``load_model_details``.
        train_directory: Directory handed to ``load_label_encoder``.
        num_samples: Number of LIME samples.
        num_features: Number of LIME features.
        segmentation_alg, kernel_size, max_dist, ratio: Accepted for
            interface compatibility.
            NOTE(review): these four values are never referenced in this
            function, so they do not influence the LIME segmentation.
        max_evals: SHAP evaluation budget.
        batch_size: SHAP batch size.
        explainer_types: Container of explainer names; ``'gradcam'``,
            ``'lime'`` and ``'shap'`` are recognised.
        output_folder: Output directory; ``None`` selects "explanations".

    Side effects:
        Writes ``gradcam_*.png``, ``splime_*.png`` and/or ``shap_*.png`` to
        the output folder and increments the module-level ``image_counter``.
    """
    global image_counter

    if output_folder is None:
        output_folder = "explanations"
    os.makedirs(output_folder, exist_ok=True)

    model, last_conv_layer_name, input_shape = load_model_details(model_path)
    label_encoder = load_label_encoder(train_directory)
    labels = list(label_encoder.values())

    # Load the image
    image = load_img(image_path, target_size=input_shape)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    array = img_to_array(image)
    img_array = np.expand_dims(array / 255.0, axis=0)

    # Predict the class of the image
    predictions = model.predict(img_array)
    top_prediction = int(np.argmax(predictions[0]))
    top_label = label_encoder[top_prediction]

    print(f"Prediction: {top_label} with probability {predictions[0][top_prediction]:.4f}")

    # Generate explanations based on user-specified types
    if 'gradcam' in explainer_types:
        # Grad-CAM is computed with the final activation removed. Restore it
        # afterwards so the LIME/SHAP steps below see the unmodified model.
        output_layer = model.layers[-1]
        saved_activation = output_layer.activation
        output_layer.activation = None
        try:
            # Explain the predicted class explicitly rather than relying on
            # the callee's default.
            heatmap = make_gradcam_heatmap(
                img_array, model, last_conv_layer_name, pred_index=top_prediction
            )
        finally:
            output_layer.activation = saved_activation
        gradcam_image = save_and_display_gradcam(array, heatmap)
        gradcam_image.save(os.path.join(output_folder, f"gradcam_{image_counter}.png"))

    if 'lime' in explainer_types:
        # SPLIME Explanation
        explainer = LimeImageExplainer()
        splime_mask, explanation_instance = generate_splime_mask_top_n(
            img_array[0], model, explainer, top_n=1,
            num_features=num_features, num_samples=num_samples
        )
        # Ensure splime_mask is in [0, 1] range before saving
        splime_mask = np.clip(splime_mask, 0, 1)
        plt.imsave(os.path.join(output_folder, f"splime_{image_counter}.png"), splime_mask)

    if 'shap' in explainer_types:
        custom_image = img_to_array(image) / 255.0  # Preprocess image for SHAP
        shap_values = explain_image_shap(
            custom_image.reshape(1, *custom_image.shape), model, labels, top_prediction,
            max_evals=max_evals, batch_size=batch_size
        )
        shap.image_plot(shap_values[0], custom_image, labels=[top_label], show=False)
        plt.savefig(os.path.join(output_folder, f"shap_{image_counter}.png"))
        plt.close()

    print("Image classification and explanation process completed.")
    image_counter += 1
|
231 |
+
|
232 |
+
if __name__ == "__main__":
    # Command-line entry point: parse the options and run one classification
    # plus the requested explanations.
    parser = argparse.ArgumentParser(description="Image classification and explanation script")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the input image")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
    parser.add_argument("--train_directory", type=str, required=True, help="Directory containing training images")
    parser.add_argument("--num_samples", type=int, default=300, help="Number of samples for LIME")
    parser.add_argument("--num_features", type=int, default=100, help="Number of features for LIME")
    parser.add_argument("--segmentation_alg", type=str, default='quickshift', help="Segmentation algorithm for LIME (options: quickshift, slic)")
    parser.add_argument("--kernel_size", type=int, default=4, help="Kernel size for segmentation algorithm")
    parser.add_argument("--max_dist", type=int, default=200, help="Max distance for segmentation algorithm")
    parser.add_argument("--ratio", type=float, default=0.2, help="Ratio for segmentation algorithm")
    parser.add_argument("--max_evals", type=int, default=400, help="Maximum evaluations for SHAP")
    parser.add_argument("--batch_size", type=int, default=50, help="Batch size for SHAP")
    parser.add_argument("--explainer_types", type=str, default='all', help="Comma-separated list of explainers to use (options: lime, shap, gradcam). Use 'all' to include all three.")
    parser.add_argument("--output_folder", type=str, default=None, help="Output folder for explanations")

    args = parser.parse_args()

    # Normalise the list so that input such as "lime, shap" or "LIME,shap"
    # still selects the intended explainers; empty entries are dropped.
    requested = args.explainer_types.strip().lower()
    if requested == 'all':
        explainer_types = ['lime', 'shap', 'gradcam']
    else:
        explainer_types = [name.strip() for name in requested.split(',') if name.strip()]

    classify_image_and_explain(
        args.image_path, args.model_path, args.train_directory, args.num_samples,
        args.num_features, args.segmentation_alg, args.kernel_size, args.max_dist,
        args.ratio, args.max_evals, args.batch_size, explainer_types, args.output_folder
    )
|
data_loader.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
5 |
+
from tqdm import tqdm # For progress display
|
6 |
+
import sys
|
7 |
+
import uuid # Import uuid for unique filename generation
|
8 |
+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
9 |
+
|
10 |
+
def parse_arguments():
    """Build the data-loader command-line parser and parse ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``path``, ``dim``, ``batch_size``,
        ``num_workers``, ``target_folder`` and ``augment_data``.
    """
    cli = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')

    # Required locations.
    cli.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
    cli.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')

    # Optional tuning knobs.
    cli.add_argument('--dim', type=int, default=224, help='Required image dimension')
    cli.add_argument('--batch_size', type=int, default=32, help='Batch size')
    cli.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
    cli.add_argument('--augment_data', action='store_true', help='Apply data augmentation')

    parsed = cli.parse_args()
    return parsed
|
19 |
+
|
20 |
+
def create_datagens():
    """Return one ``ImageDataGenerator`` per augmentation setting.

    Returns:
        List of five generators, in order: rescale, rotation, width shift,
        height shift and horizontal flip.
    """
    # NOTE(review): the first generator only sets ``rescale``; confirm that
    # the caller's ``random_transform`` actually changes images for it.
    augmentation_settings = (
        {'rescale': 1./255},
        {'rotation_range': 20},
        {'width_shift_range': 0.2},
        {'height_shift_range': 0.2},
        {'horizontal_flip': True},
    )
    return [ImageDataGenerator(**settings) for settings in augmentation_settings]
|
29 |
+
|
30 |
+
def process_image(file_path, image_size):
    """Read an image file and return it as a square float tensor.

    Args:
        file_path: Tensor holding the UTF-8 encoded path; ``.numpy()`` is
            called on it, so it must be an eager tensor.
        image_size: Side length used for both height and width.

    Returns:
        Three-channel float32 tensor resized to ``image_size`` x
        ``image_size`` with values clipped to [0, 1].
    """
    path_text = file_path.numpy().decode('utf-8')
    raw_bytes = tf.io.read_file(path_text)
    decoded = tf.image.decode_image(raw_bytes, channels=3, dtype=tf.float32)
    resized = tf.image.resize(decoded, [image_size, image_size])
    # Clamp after resizing so every value stays inside [0, 1].
    return tf.clip_by_value(resized, 0.0, 1.0)
|
38 |
+
|
39 |
+
def save_image(image, file_path):
    """Encode ``image`` as JPEG and write it to ``file_path``.

    Args:
        image: Image tensor/array; converted to uint8 before encoding.
        file_path: Destination file path.
    """
    as_uint8 = tf.image.convert_image_dtype(image, dtype=tf.uint8)
    jpeg_bytes = tf.image.encode_jpeg(as_uint8)
    tf.io.write_file(file_path, jpeg_bytes)
|
44 |
+
|
45 |
+
def load_data(path, image_size, batch_size):
    """Collect labelled images under ``path`` and build train/val/test datasets.

    Each image's label is the name of the directory that contains it. The
    files are split 80/10/10 with two stratified shuffles (seed 42).

    Args:
        path: Root directory that is walked recursively.
        image_size: Side length images are resized to.
        batch_size: Batch size of the returned datasets.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of batched,
        prefetched ``tf.data.Dataset`` objects yielding ``(image, label)``.

    Raises:
        ValueError: If no image files are found under ``path``.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')
    all_images = []
    labels = []

    # Load images and labels from the specified path
    for subdir, dirs, files in os.walk(path):
        # Sort in place so traversal order, and therefore the seeded split,
        # does not depend on filesystem enumeration order.
        dirs.sort()
        label = os.path.basename(subdir)
        for fname in sorted(files):
            # Case-insensitive match so files such as "IMG_001.JPG" are kept.
            if fname.lower().endswith(image_extensions):
                all_images.append(os.path.join(subdir, fname))
                labels.append(label)

    unique_labels = set(labels)
    print(f"Found {len(all_images)} images in {path}\n")
    print(f"Labels found ({len(unique_labels)}): {unique_labels}\n")

    # Raise an error if no images are found
    if len(all_images) == 0:
        raise ValueError(f"No images found in the specified path: {path}")

    # Stratified splitting the dataset: 80% train, 20% held out.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_indices, holdout_indices = next(sss.split(all_images, labels))

    train_files = [all_images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    holdout_files = [all_images[i] for i in holdout_indices]
    holdout_labels = [labels[i] for i in holdout_indices]

    # Split the held-out part in half: validation and test.
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(holdout_files, holdout_labels))

    val_files = [holdout_files[i] for i in val_indices]
    val_labels = [holdout_labels[i] for i in val_indices]
    test_files = [holdout_files[i] for i in test_indices]
    test_labels = [holdout_labels[i] for i in test_indices]

    print(f"Data split into {len(train_files)} train, {len(val_files)} validation, and {len(test_files)} test images.\n")

    def tf_load_and_augment_image(file_path, label):
        # process_image calls .numpy() on the path, so it must run eagerly
        # through py_function; the static shape is restored afterwards.
        image = tf.py_function(func=lambda x: process_image(x, image_size), inp=[file_path], Tout=tf.float32)
        image.set_shape([image_size, image_size, 3])
        return image, label

    def build_dataset(files, file_labels):
        # Create a batched, prefetched dataset from file paths and labels.
        dataset = tf.data.Dataset.from_tensor_slices((files, file_labels))
        dataset = dataset.map(lambda x, y: tf_load_and_augment_image(x, y))
        return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    train_dataset = build_dataset(train_files, train_labels)
    val_dataset = build_dataset(val_files, val_labels)
    test_dataset = build_dataset(test_files, test_labels)

    return train_dataset, val_dataset, test_dataset
|
104 |
+
|
105 |
+
def save_datasets_to_folders(dataset, folder_path, datagens=None):
    """Write every image of ``dataset`` into per-label subfolders as JPEG.

    Args:
        dataset: Iterable of ``(batch_images, batch_labels)`` batches; labels
            are UTF-8 byte strings used as subfolder names.
        folder_path: Root output folder, created when missing.
        datagens: Optional generators; each one contributes an extra
            ``random_transform`` copy of every image.

    Returns:
        Total number of image files written (originals plus augmented copies).
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    saved = 0
    for batch_images, batch_labels in tqdm(dataset, desc=f"Saving to {folder_path}"):
        for image_tensor, label_tensor in zip(batch_images, batch_labels):
            image = image_tensor.numpy()
            label = label_tensor.numpy().decode('utf-8')

            label_folder = os.path.join(folder_path, label)
            if not os.path.exists(label_folder):
                os.makedirs(label_folder)

            # The untouched image is always written; each file gets a random
            # hex name generated by uuid4.
            save_image(image, os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg"))
            saved += 1

            # One additional copy per generator (nothing when none given).
            for datagen in datagens or ():
                aug_image = datagen.random_transform(image)
                save_image(aug_image, os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg"))
                saved += 1

    print(f"Saved {saved} images to {folder_path}\n")
    return saved
|
134 |
+
|
135 |
+
def main():
    """Parse the CLI, split the source images and save the three splits.

    Only the training split receives the optional augmentations.
    """
    args = parse_arguments()

    if not os.path.exists(args.target_folder):
        os.makedirs(args.target_folder)

    train_folder, val_folder, test_folder = (
        os.path.join(args.target_folder, split_name)
        for split_name in ('train', 'val', 'test')
    )

    datagens = None
    if args.augment_data:
        datagens = create_datagens()

    train_dataset, val_dataset, test_dataset = load_data(args.path, args.dim, args.batch_size)

    # Save datasets to respective folders and count images
    train_count = save_datasets_to_folders(train_dataset, train_folder, datagens)
    val_count = save_datasets_to_folders(val_dataset, val_folder)
    test_count = save_datasets_to_folders(test_dataset, test_folder)

    for heading, folder in (("Train", train_folder), ("Validation", val_folder), ("Test", test_folder)):
        print(f"{heading} dataset saved to: {folder}\n")

    print('-'*20)

    for split_label, total in (("training", train_count), ("validation", val_count), ("test", test_count)):
        print(f"Number of images in {split_label} set: {total}\n")
|
168 |
+
|
169 |
+
if __name__ == "__main__":
    # Re-open stdout and stderr on their existing file descriptors as
    # line-buffered UTF-8 text streams to avoid encoding issues when printing.
    for stream_name in ('stdout', 'stderr'):
        descriptor = getattr(sys, stream_name).fileno()
        setattr(sys, stream_name, open(descriptor, mode='w', encoding='utf-8', buffering=1))
    main()
|
predict.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
from tensorflow.keras.preprocessing import image
|
6 |
+
from tensorflow.keras.models import load_model
|
7 |
+
|
8 |
+
def load_and_preprocess_image(img_path, target_size):
    """Load an image and turn it into a normalised single-item batch.

    Args:
        img_path: Path of the image file.
        target_size: Size handed to ``image.load_img``.

    Returns:
        Array with a leading batch axis and values divided by 255.
    """
    loaded = image.load_img(img_path, target_size=target_size)
    pixels = image.img_to_array(loaded)
    # Add the batch axis first, then scale.
    batch = np.expand_dims(pixels, axis=0)
    return batch / 255.0
|
16 |
+
|
17 |
+
def load_model_from_file(model_path):
    """Load the saved model at ``model_path``, report it and return it."""
    loaded_model = load_model(model_path)
    print(f"Model loaded from {model_path}")
    return loaded_model
|
23 |
+
|
24 |
+
def make_predictions(model, img_array):
    """Return ``model.predict(img_array)`` unchanged."""
    return model.predict(img_array)
|
29 |
+
|
30 |
+
def get_class_names(train_dir):
    """Get class names from the training directory.

    Only sub-directories are treated as classes. Stray files in
    ``train_dir`` (for example ``.DS_Store`` or a README) are ignored,
    because counting them would shift the index-to-label mapping.

    Args:
        train_dir: Directory whose sub-folder names are the class labels.

    Returns:
        Alphabetically sorted list of class names.
    """
    return sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )
|
35 |
+
|
36 |
+
def main(model_path, img_path, train_dir):
    """Predict the class of one image and print the label with its score.

    Args:
        model_path: Path of the saved model.
        img_path: Path of the image to classify.
        train_dir: Directory used to look up the class names.
    """
    # Input size the image is resized to; adjust if the model differs.
    target_size = (224, 224)

    model = load_model_from_file(model_path)
    class_names = get_class_names(train_dir)

    img_array = load_and_preprocess_image(img_path, target_size)
    predictions = make_predictions(model, img_array)

    # Index of the highest score for the single image in the batch.
    best_index = np.argmax(predictions, axis=1)[0]
    predicted_label = class_names[best_index]
    probability_score = predictions[0][best_index]

    print(f"Predicted label: {predicted_label}, Probability: {probability_score:.4f}")
|
57 |
+
|
58 |
+
if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory.
    cli = argparse.ArgumentParser(description="Load a pre-trained model and make a prediction on a new image")
    cli.add_argument('--model_path', type=str, required=True, help='Path to the saved model')
    cli.add_argument('--img_path', type=str, required=True, help='Path to the image to be predicted')
    cli.add_argument('--train_dir', type=str, required=True, help='Directory containing training dataset for inferring class names')

    cli_args = cli.parse_args()
    main(cli_args.model_path, cli_args.img_path, cli_args.train_dir)
|
requirements.txt
ADDED
Binary file (520 Bytes). View file
|
|
test.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
from tensorflow.keras.preprocessing import image
|
6 |
+
from tensorflow.keras.models import load_model
|
7 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import seaborn as sns
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
# Load and preprocess an image for prediction
|
13 |
+
def load_and_preprocess_image(img_path, target_size):
    """Load an image and turn it into a normalised single-item batch.

    Args:
        img_path: Path of the image file.
        target_size: Size handed to ``image.load_img``.

    Returns:
        Array with a leading batch axis and values divided by 255.
    """
    loaded = image.load_img(img_path, target_size=target_size)
    pixels = image.img_to_array(loaded)
    # Add the batch axis first, then scale.
    batch = np.expand_dims(pixels, axis=0)
    return batch / 255.0
|
20 |
+
|
21 |
+
# Load all models from a specified directory
|
22 |
+
def load_all_models(model_dir):
    """Load every ``*_model.keras`` file found directly in ``model_dir``.

    Args:
        model_dir: Directory to scan (not recursive).

    Returns:
        Dict mapping the file name up to ``_model.keras`` to the loaded model.

    Raises:
        FileNotFoundError: If no matching model file exists.
    """
    suffix = '_model.keras'
    models = {}

    for file_name in os.listdir(model_dir):
        if not file_name.endswith(suffix):
            continue
        model_path = os.path.join(model_dir, file_name)
        # Everything before the suffix serves as the model's name.
        model_name = file_name.split(suffix)[0]
        models[model_name] = load_model(model_path)
        print(f"Model loaded from {model_path}")

    if not models:
        raise FileNotFoundError(f"No model files found in {model_dir}.")
    return models
|
35 |
+
|
36 |
+
# Load a single model from a specified path
|
37 |
+
def load_model_from_file(model_path):
    """Load the saved model at ``model_path``, report it and return it."""
    loaded_model = load_model(model_path)
    print(f"Model loaded from {model_path}")
    return loaded_model
|
42 |
+
|
43 |
+
def make_predictions(model, img_array):
    """Return ``model.predict(img_array)`` unchanged."""
    return model.predict(img_array)
|
48 |
+
|
49 |
+
def get_class_names(train_dir):
    """Get class names from the training directory.

    Only sub-directories are treated as classes. Stray files in
    ``train_dir`` (for example ``.DS_Store`` or a README) are ignored,
    because counting them would shift the index-to-label mapping.

    Args:
        train_dir: Directory whose sub-folder names are the class labels.

    Returns:
        Alphabetically sorted list of class names.
    """
    return sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )
|
54 |
+
|
55 |
+
# Compute confusion matrix and classification report, and save to log directory
|
56 |
+
def compute_confusion_matrix_and_report(true_labels, predicted_labels, class_names, log_dir, model_name):
    """Compute confusion matrix and classification report, and save to log directory.

    Args:
        true_labels: Ground-truth label names.
        predicted_labels: Predicted label names, aligned with ``true_labels``.
        class_names: All class names; fixes the row/column order of the
            matrix and the rows of the report.
        log_dir: Output directory, created when missing.
        model_name: Name used in the plot title and the output file names.

    Side effects:
        Prints the report and writes ``confusion_matrix_<model>.png``,
        ``confusion_matrix_<model>.txt`` and
        ``classification_report_<model>.txt`` into ``log_dir``.
    """
    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=class_names)
    # Pass ``labels`` explicitly so the report covers exactly ``class_names``
    # in the same order as the matrix. Without it, a class missing from the
    # evaluated data makes ``target_names`` mismatch the labels found.
    report = classification_report(
        true_labels, predicted_labels, labels=class_names, target_names=class_names
    )

    # Print the classification report
    print(f"Model: {model_name}")
    print(report)

    # Plot the confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'Confusion Matrix - {model_name}')

    # Save plot
    os.makedirs(log_dir, exist_ok=True)

    conf_matrix_plot_file = os.path.join(log_dir, f'confusion_matrix_{model_name}.png')
    plt.savefig(conf_matrix_plot_file)
    plt.close()

    # Save results to log directory
    conf_matrix_file = os.path.join(log_dir, f'confusion_matrix_{model_name}.txt')
    report_file = os.path.join(log_dir, f'classification_report_{model_name}.txt')

    np.savetxt(conf_matrix_file, conf_matrix, fmt='%d', delimiter=',', header=','.join(class_names))
    with open(report_file, 'w') as f:
        f.write(report)

    print(f"Confusion matrix and classification report saved to {log_dir} with model name: {model_name}")
|
90 |
+
|
91 |
+
# Main function to load models, make predictions, and evaluate performance
|
92 |
+
def main(model_path, model_dir, img_path, test_dir, train_dir, log_dir):
    """Load one or more models, then predict a single image and/or evaluate a test set.

    Args:
        model_path: Path of a single saved model; takes precedence over
            ``model_dir``.
        model_dir: Directory from which all models are loaded.
        img_path: Optional image to classify with every loaded model.
        test_dir: Optional directory of test images; the name of each
            image's parent folder is taken as its true label.
        train_dir: Directory used to look up the class names.
        log_dir: Directory receiving the evaluation outputs.

    Raises:
        ValueError: If neither ``model_path`` nor ``model_dir`` is given, if
            a predicted index exceeds the class list, or if a test image's
            folder is not a known class.
    """
    # Define target image size based on model requirements
    target_size = (224, 224)  # Adjust if needed

    if model_path:
        # Load a single model
        model = load_model_from_file(model_path)
        models = {os.path.basename(model_path): model}
    elif model_dir:
        # Load all models from a directory
        models = load_all_models(model_dir)
    else:
        raise ValueError("Either --model_path or --model_dir must be provided.")

    # Get class names from train directory
    class_names = get_class_names(train_dir)
    num_classes = len(class_names)

    # If an image path is provided, perform prediction on that image
    if img_path:
        img_array = load_and_preprocess_image(img_path, target_size)
        for model_name, model in models.items():
            print(f"Model: {model_name}")
            predictions = make_predictions(model, img_array)
            predicted_label_index = np.argmax(predictions, axis=1)[0]
            if predicted_label_index >= num_classes:
                raise ValueError(f"Predicted label index {predicted_label_index} is out of range for class names list.")
            predicted_label = class_names[predicted_label_index]
            probability_score = predictions[0][predicted_label_index]
            print('-'*20)
            print(f"Predicted label: {predicted_label}, Probability: {probability_score:.4f}")
            print('-'*20)

    # If a test directory is provided, perform batch predictions and evaluation
    if test_dir:
        image_extensions = ('.png', '.jpg', '.jpeg')
        # Case-insensitive extension match so files such as "IMG.JPG" are
        # evaluated too; sorted so runs process files in a stable order.
        files = sorted(
            os.path.join(root, file_name)
            for root, _, file_names in os.walk(test_dir)
            for file_name in file_names
            if file_name.lower().endswith(image_extensions)
        )

        for model_name, model in models.items():
            true_labels = []
            predicted_labels = []

            # A dedicated loop variable avoids shadowing the img_path parameter.
            for test_image_path in tqdm(files, desc=f"Processing images with {model_name}"):
                img_array = load_and_preprocess_image(test_image_path, target_size)
                predictions = make_predictions(model, img_array)
                predicted_label_index = np.argmax(predictions, axis=1)[0]
                if predicted_label_index >= num_classes:
                    raise ValueError(f"Predicted label index {predicted_label_index} is out of range for class names list.")
                predicted_label = class_names[predicted_label_index]

                # Assuming folder name is the label
                true_label = os.path.basename(os.path.dirname(test_image_path))
                if true_label not in class_names:
                    raise ValueError(f"True label {true_label} is not in class names list.")

                true_labels.append(true_label)
                predicted_labels.append(predicted_label)

            # Compute and save confusion matrix and classification report
            compute_confusion_matrix_and_report(true_labels, predicted_labels, class_names, log_dir, model_name)
|
150 |
+
|
151 |
+
if __name__ == "__main__":
    # Command-line entry point: only --train_dir and --log_dir are mandatory;
    # the model and image/test options are validated inside main().
    cli = argparse.ArgumentParser(description="Load models and make predictions on new images or a test dataset")
    cli.add_argument('--model_path', type=str, help='Path to a single saved model')
    cli.add_argument('--model_dir', type=str, help='Directory containing saved models (loads all models in the folder)')
    cli.add_argument('--img_path', type=str, help='Path to the image to be predicted')
    cli.add_argument('--test_dir', type=str, help='Directory containing test dataset for batch predictions')
    cli.add_argument('--train_dir', type=str, required=True, help='Directory containing training dataset for inferring class names')
    cli.add_argument('--log_dir', type=str, required=True, help='Directory to save prediction results')

    cli_args = cli.parse_args()
    main(
        cli_args.model_path,
        cli_args.model_dir,
        cli_args.img_path,
        cli_args.test_dir,
        cli_args.train_dir,
        cli_args.log_dir,
    )
|
train.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras.models import Model
|
5 |
+
from tensorflow.keras.applications import (VGG16, VGG19, ResNet50, ResNet101, InceptionV3,
|
6 |
+
DenseNet121, DenseNet201, MobileNetV2, Xception, InceptionResNetV2,
|
7 |
+
NASNetLarge, NASNetMobile, EfficientNetB0, EfficientNetB7)
|
8 |
+
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization
|
9 |
+
from tensorflow.keras.optimizers import Adam, SGD
|
10 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
def load_and_preprocess_image(filename, label, image_size):
    """Read, decode, resize and normalise one image for training.

    Args:
        filename: Path of the image file (string or string tensor).
        label: Label passed through unchanged.
        image_size: Sequence whose first two items are target height and width.

    Returns:
        Tuple ``(image, label)`` where ``image`` is a
        ``(height, width, 3)`` tensor with values divided by 255.
    """
    # Load image
    image = tf.io.read_file(filename)
    # expand_animations=False keeps the result 3-D even for animated GIFs
    # (first frame only); otherwise the shape set below would not hold.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)

    # Ensure the image tensor has shape
    if not tf.executing_eagerly():
        image.set_shape([None, None, 3])

    # Resize image to the specified size
    image = tf.image.resize(image, [image_size[0], image_size[1]])  # Use height and width from the tuple

    # Normalize image to [0, 1]
    image = image / 255.0
    image.set_shape([image_size[0], image_size[1], 3])

    return image, label
|
30 |
+
|
31 |
+
def create_dataset(data_dir, labels, image_size, batch_size):
    """Build a shuffled, batched dataset from per-label image folders.

    Args:
        data_dir: Directory containing one sub-folder per entry of ``labels``.
        labels: Class names; a name's position in this sequence becomes its
            integer label.
        image_size: Target size forwarded to ``load_and_preprocess_image``.
        batch_size: Number of samples per batch.

    Returns:
        ``tf.data.Dataset`` yielding ``(image_batch, label_batch)``.

    Raises:
        ValueError: If the label folders contain no files at all.
    """
    # Create a mapping from labels to indices
    label_map = {label: idx for idx, label in enumerate(labels)}

    image_files = []
    image_labels = []
    for label in labels:
        label_dir = os.path.join(data_dir, label)
        for image_file in os.listdir(label_dir):
            image_files.append(os.path.join(label_dir, image_file))
            image_labels.append(label_map[label])

    if not image_files:
        raise ValueError(f"No images found in the specified path: {data_dir}")

    # Convert to TensorFlow datasets
    dataset = tf.data.Dataset.from_tensor_slices((image_files, image_labels))
    # Shuffle the (path, label) pairs BEFORE decoding. Shuffling after the
    # map with a full-size buffer would hold every decoded image in memory.
    dataset = dataset.shuffle(buffer_size=len(image_files))
    dataset = dataset.map(
        lambda x, y: load_and_preprocess_image(x, y, image_size),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset
|
52 |
+
|
53 |
+
def create_and_train_model(base_model, model_name, shape, X_train, X_val, num_classes, labels, log_dir, model_dir,
                           epochs, optimizer_name, learning_rate, step_gamma, alpha, batch_size, patience):
    """Attach a classification head to a frozen base model, then train and save it.

    Args:
        base_model: Keras model used as feature extractor; all of its layers
            are frozen.
        model_name: Name used in the checkpoint, log and final model file names.
        shape: Input shape of the images. Not used here; kept for interface
            compatibility.
        X_train: Batched training dataset yielding ``(images, labels)``.
        X_val: Batched validation dataset yielding ``(images, labels)``.
        num_classes: Number of output classes of the softmax layer.
        labels: Class names, written one per line to ``labels.txt``.
        log_dir: Directory that receives ``<model_name>_training.log``.
        model_dir: Directory that receives the checkpoints, the final model
            and ``labels.txt``.
        epochs: Maximum number of training epochs.
        optimizer_name: ``'adam'`` or ``'sgd'`` (case-insensitive).
        learning_rate: Initial learning rate of the exponential decay schedule.
        step_gamma: Decay rate of the exponential decay schedule.
        alpha: Momentum, used only by the SGD optimizer.
        batch_size: Not used here; the datasets are already batched. Kept for
            interface compatibility.
        patience: Number of epochs without ``val_accuracy`` improvement before
            training stops early.

    Raises:
        ValueError: If ``optimizer_name`` is neither ``'adam'`` nor ``'sgd'``.
    """
    # Reject an unknown optimizer before any model is built.
    optimizer_key = optimizer_name.lower()
    if optimizer_key not in ('adam', 'sgd'):
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    # Freeze the base model layers so only the new head is trained.
    for layer in base_model.layers:
        layer.trainable = False

    # Add custom layers on top.
    x = base_model.output
    x = Flatten()(x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.25)(x)

    x = Dense(512, activation='relu')(x)
    x = Dropout(0.25)(x)

    x = Dense(256, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.25)(x)

    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)

    # Learning rate schedule.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=learning_rate,
        decay_steps=1000,  # Adjust this according to your needs
        decay_rate=step_gamma
    )

    if optimizer_key == 'adam':
        optimizer = Adam(learning_rate=lr_schedule)
    else:
        optimizer = SGD(learning_rate=lr_schedule, momentum=alpha)

    # Labels are integer class indices, hence the sparse loss.
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Set up callbacks.
    checkpoint = ModelCheckpoint(os.path.join(model_dir, f'{model_name}_best_model.keras'),
                                 monitor='val_accuracy', save_best_only=True, save_weights_only=False,
                                 mode='max', verbose=1)
    early_stopping = EarlyStopping(monitor='val_accuracy', patience=patience, verbose=1)

    # Train the model. `batch_size` is deliberately not forwarded: Keras
    # documents that it must not be specified when the input is a dataset,
    # because the dataset already produces batches.
    history = model.fit(X_train, epochs=epochs, validation_data=X_val,
                        callbacks=[checkpoint, early_stopping])

    # Save training logs, one line per completed epoch (early stopping may
    # end training before `epochs` is reached).
    metrics = history.history
    per_epoch = zip(metrics['loss'], metrics['accuracy'], metrics['val_loss'], metrics['val_accuracy'])
    with open(os.path.join(log_dir, f'{model_name}_training.log'), 'w') as f:
        for epoch, (loss, accuracy, val_loss, val_accuracy) in enumerate(per_epoch, start=1):
            f.write(f"Epoch {epoch}, "
                    f"Train Loss: {loss:.4f}, "
                    f"Train Accuracy: {accuracy:.4f}, "
                    f"Val Loss: {val_loss:.4f}, "
                    f"Val Accuracy: {val_accuracy:.4f}\n")

    # Save labels in the model directory.
    with open(os.path.join(model_dir, 'labels.txt'), 'w') as f:
        f.write('\n'.join(labels))

    # Evaluate the model. NOTE: this uses the validation dataset and the
    # weights of the last epoch, not a separate test split.
    test_loss, test_accuracy = model.evaluate(X_val)
    print(f'Test Accuracy for {model_name}: {test_accuracy:.4f}')
    print(f'Test Loss for {model_name}: {test_loss:.4f}')

    # Save the model as it is after the last epoch (the best one by
    # val_accuracy is saved separately by the checkpoint callback).
    model.save(os.path.join(model_dir, f'{model_name}_final_model.keras'))
|
124 |
+
|
125 |
+
def main(base_model_names, shape, data_path, log_dir, model_dir, epochs, optimizer, learning_rate, step_gamma, alpha, batch_size, patience):
    """Train one transfer-learning classifier per requested base model.

    Class labels are the names of the sub-directories of ``<data_path>/train``;
    the validation data is read from ``<data_path>/val`` using the same labels.

    Args:
        base_model_names: Names of Keras application models (e.g.
            ``'DenseNet121'``) that are imported in this module. Unknown names
            are reported and skipped.
        shape: Input shape ``(height, width, channels)`` of the images.
        data_path: Directory containing the ``train`` and ``val`` folders.
        log_dir: Directory for training logs; created if missing.
        model_dir: Directory for saved models; created if missing.
        epochs: Maximum number of training epochs.
        optimizer: Optimizer name, ``'adam'`` or ``'sgd'``.
        learning_rate: Initial learning rate.
        step_gamma: Decay rate of the learning rate schedule.
        alpha: Momentum for the SGD optimizer.
        batch_size: Batch size of the datasets.
        patience: Early stopping patience in epochs.
    """
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Nothing to train: return before touching the data directory.
    if not base_model_names:
        print("No base models specified. Exiting.")
        return

    # Extract labels from folder names.
    train_dir = os.path.join(data_path, 'train')
    labels = sorted(d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d)))
    num_classes = len(labels)

    # Load data.
    X_train = create_dataset(train_dir, labels, shape, batch_size)
    X_val = create_dataset(os.path.join(data_path, 'val'), labels, shape, batch_size)

    for model_name in base_model_names:
        print(f'Training {model_name}...')

        # Only accept names that are both Keras applications and imported in
        # this module, so an arbitrary global cannot be invoked from the
        # command line and an unknown name does not raise KeyError.
        builder = globals().get(model_name) if hasattr(tf.keras.applications, model_name) else None
        if not callable(builder):
            print(f"Model {model_name} not supported.")
            continue

        # Build each base model right before it is trained, so the pretrained
        # weights of all requested models are not held in memory at once.
        base_model = builder(weights='imagenet', include_top=False, input_shape=shape)
        create_and_train_model(base_model, model_name, shape, X_train, X_val, num_classes, labels, log_dir, model_dir,
                               epochs, optimizer, learning_rate, step_gamma, alpha, batch_size, patience)
|
157 |
+
|
158 |
+
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and hand the
    # values to main().
    cli = argparse.ArgumentParser(description="Train models using transfer learning")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--base_models', dict(type=str, nargs='+', default=[],
                               help='List of base models to use for training. Leave empty to skip model training.')),
        ('--shape', dict(type=int, nargs=3, default=(224, 224, 3), help='Input shape of the images')),
        ('--data_path', dict(type=str, required=True, help='Path to the image data')),
        ('--log_dir', dict(type=str, required=True, help='Directory to save logs')),
        ('--model_dir', dict(type=str, required=True, help='Directory to save models')),
        ('--epochs', dict(type=int, default=100, help='Number of epochs to train')),
        ('--optimizer', dict(type=str, default='adam', help='Optimizer to use (adam or sgd)')),
        ('--learning_rate', dict(type=float, default=0.001, help='Initial learning rate')),
        ('--step_gamma', dict(type=float, default=0.96, help='Gamma value for step learning rate schedule')),
        ('--alpha', dict(type=float, default=0.9, help='Alpha for the optimizer (used for SGD)')),
        ('--batch_size', dict(type=int, default=32, help='Batch size for training')),
        ('--patience', dict(type=int, default=10, help='Patience for early stopping')),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    # argparse returns --shape as a list when given on the command line;
    # main() receives it as a tuple.
    main(
        base_model_names=options.base_models,
        shape=tuple(options.shape),
        data_path=options.data_path,
        log_dir=options.log_dir,
        model_dir=options.model_dir,
        epochs=options.epochs,
        optimizer=options.optimizer,
        learning_rate=options.learning_rate,
        step_gamma=options.step_gamma,
        alpha=options.alpha,
        batch_size=options.batch_size,
        patience=options.patience,
    )
|