Purushothamann committed
Commit ffd6b68 · verified · 1 parent: fc9772b

Upload 9 files


uploaded the code and sample models

Files changed (9)
  1. LICENSE +21 -0
  2. README.md +266 -3
  3. balanced_data_loader-1.py +216 -0
  4. classify_image_and_explain.py +256 -0
  5. data_loader.py +173 -0
  6. predict.py +65 -0
  7. requirements.txt +0 -0
  8. test.py +161 -0
  9. train.py +176 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Purushothaman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,266 @@
- ---
- license: mit
- ---
+ # Interpretable-SONAR-Image-Classifier
+
+ Explainable AI for an underwater SONAR image classifier.
+
+ ## Prerequisites
+
+ - Python 3.6 or higher
+
+ ## Running the Scripts
+
+ This guide will help you run the `data_loader.py`, `train.py`, `test.py`, `predict.py`, and `classify_image_and_explain.py` scripts directly from the command line or from within a Python script.
+
+ ### Prerequisites
+
+ 1. **Python Installation**: Ensure you have Python installed. You can download it from [python.org](https://www.python.org/).
+
+ 2. **Required Packages**: Install the required packages using `requirements.txt`.
+ ```sh
+ pip install -r requirements.txt
+ ```
+
+ ### Script Descriptions and Usage
+
+ #### 1. `data_loader.py`
+
+ This script loads, processes, and splits a dataset into train, val, and test sets, with optional data augmentation.
+
+ **Command Line Usage:**
+
+ ```sh
+ python data_loader.py --path <path_to_data> --target_folder <path_to_target_folder> --dim <dimension> --batch_size <batch_size> --num_workers <num_workers> [--augment_data]
+ ```
+
+ **Arguments:**
+
+ - `--path`: Path to the data.
+ - `--target_folder`: Path to the target folder where processed data will be saved.
+ - `--dim`: Dimension for resizing the images.
+ - `--batch_size`: Batch size for data loading.
+ - `--num_workers`: Number of workers for data loading.
+ - `--augment_data` (optional): Flag to enable data augmentation.
+
+ **Example:**
+
+ ```sh
+ python data_loader.py --path "./dataset" --target_folder "./processed_data" --dim 224 --batch_size 32 --num_workers 4 --augment_data
+ ```
+
+ **Dataset Structure:**
+
+ ```sh
+ ├── Dataset (Raw)
+ │   ├── class_name_1
+ │   │   └── *.jpg
+ │   ├── class_name_2
+ │   │   └── *.jpg
+ │   ├── class_name_3
+ │   │   └── *.jpg
+ │   └── class_name_4
+ │       └── *.jpg
+ ```
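+
+ After running, `data_loader.py` writes `train/`, `val/`, and `test/` splits under the target folder, each with one subfolder per class. A sketch of the resulting layout, based on the script's save logic (file names are generated UUIDs):
+
+ ```sh
+ ├── processed_data
+ │   ├── train
+ │   │   ├── class_name_1
+ │   │   │   └── <uuid>.jpg
+ │   │   └── ...
+ │   ├── val
+ │   │   └── ...
+ │   └── test
+ │       └── ...
+ ```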
+
+ #### 2. `train.py`
+
+ This script trains and saves the models, leveraging transfer learning.
+
+ **Command Line Usage:**
+
+ ```sh
+ python train.py --base_models <model_names> --shape <height> <width> <channels> --data_path <data_path> --log_dir <log_dir> --model_dir <model_dir> --epochs <epochs> --optimizer <optimizer> --learning_rate <learning_rate> --batch_size <batch_size> --patience <patience>
+ ```
+
+ **Arguments:**
+
+ - `--base_models`: Space-separated list of base model names (e.g., `VGG16 ResNet50`).
+ - `--shape`: Input image shape as three integers (height, width, channels).
+ - `--data_path`: Path to the data.
+ - `--log_dir`: Path to the log directory.
+ - `--model_dir`: Path to the model directory.
+ - `--epochs`: Number of training epochs.
+ - `--optimizer`: Optimizer type (`adam` or `sgd`).
+ - `--learning_rate`: Learning rate for the optimizer.
+ - `--batch_size`: Batch size for training.
+ - `--patience`: Patience for early stopping (to prevent overfitting).
+
+ **Example:**
+
+ ```sh
+ python train.py --base_models "VGG16" "DenseNet121" --shape 224 224 3 --data_path "./processed_data" --log_dir "./logs" --model_dir "./models" --epochs 100 --optimizer "adam" --learning_rate 0.0001 --batch_size 32
+ ```
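+
+ Per `train.py`, training writes the best checkpoint (`<model>_best_model.keras`), the final model (`<model>_final_model.keras`), and a `labels.txt` with the class names to `--model_dir`, plus a per-epoch `<model>_training.log` to `--log_dir`.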
+
+ #### 3. `test.py`
+
+ This script evaluates the trained models and stores the test logs.
+
+ **Command Line Usage:**
+
+ ```sh
+ python test.py [--model_path <model_path> | --model_dir <model_dir>] [--img_path <img_path>] [--test_dir <test_dir>] --train_dir <train_dir> --log_dir <log_dir>
+ ```
+
+ **Arguments:**
+
+ - `--model_path`: Path to a single model (`.h5`/Keras model). One of `--model_path` or `--model_dir` is required.
+ - `--model_dir` (optional): Path to a directory of saved models (loads all models in the folder).
+ - `--img_path` (optional): Path to a single image to classify.
+ - `--test_dir` (optional): Path to the test dataset (directory) for batch evaluation.
+ - `--train_dir`: Path to the training data (used to infer class names).
+ - `--log_dir`: Path to the log directory.
+
+ **Example:**
+
+ ```sh
+ python test.py --model_path "./models/vgg16_model.keras" --test_dir "./test_data" --train_dir "./data/train" --log_dir "./logs"
+ ```
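+
+ Per `test.py`, evaluation artifacts are written to `--log_dir`: `confusion_matrix_<model>.png`, `confusion_matrix_<model>.txt`, and `classification_report_<model>.txt`.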
+
+ #### 4. `predict.py`
+
+ This script makes predictions on new images.
+
+ **Command Line Usage:**
+
+ ```sh
+ python predict.py --model_path <model_path> --img_path <img_path> --train_dir <train_dir>
+ ```
+
+ **Arguments:**
+
+ - `--model_path`: Path to the model file.
+ - `--img_path`: Path to the image file.
+ - `--train_dir`: Path to the training dataset, used as the label decoder (the subfolder names provide the class labels). With slight code modifications it can be replaced with a CSV file; see the sketch below.
+
+ **Example:**
+
+ ```sh
+ python predict.py --model_path "./models/vgg16_model.keras" --img_path "./images/test_image.jpg" --train_dir "./data/train"
+ ```
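+
+ For reference, a minimal sketch of the CSV-based replacement for `get_class_names` mentioned above (the helper name and the `classes.csv` format are hypothetical, not part of the repository):
+
+ ```python
+ import csv
+
+ def get_class_names_from_csv(csv_path):
+     # Hypothetical format: one "index,class_name" row per class, e.g. "0,wreck"
+     with open(csv_path, newline="") as f:
+         rows = sorted(csv.reader(f), key=lambda r: int(r[0]))
+     return [name for _, name in rows]
+
+ # class_names = get_class_names_from_csv("classes.csv")
+ ```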
+
+ #### 5. `classify_image_and_explain.py`
+
+ This script makes a prediction on a new image and generates explanations using one or more explainers (LIME, SHAP, Grad-CAM). The explanations are saved in the specified output folder, with filenames indicating the method used and an image counter (e.g., `splime_0.png`, `shap_0.png`, `gradcam_0.png`).
+
+ **Command Line Usage:**
+
+ ```sh
+ python classify_image_and_explain.py --image_path <image_path> --model_path <model_path> --train_directory <train_directory> --num_samples <num_samples> --num_features <num_features> --segmentation_alg <segmentation_alg> --kernel_size <kernel_size> --max_dist <max_dist> --ratio <ratio> --max_evals <max_evals> --batch_size <batch_size> --explainer_types <explainer_types> --output_folder <output_folder>
+ ```
+
+ **Arguments:**
+
+ - `--image_path` (required): Path to the input image.
+ - `--model_path` (required): Path to the trained model.
+ - `--train_directory` (required): Directory containing training images.
+ - `--num_samples` (default: 300): Number of samples for LIME.
+ - `--num_features` (default: 100): Number of features for LIME.
+ - `--segmentation_alg` (default: `quickshift`): Segmentation algorithm for LIME (`quickshift`, `slic`).
+ - `--kernel_size` (default: 4): Kernel size for the segmentation algorithm.
+ - `--max_dist` (default: 200): Maximum distance for the segmentation algorithm.
+ - `--ratio` (default: 0.2): Ratio for the segmentation algorithm.
+ - `--max_evals` (default: 400): Maximum evaluations for SHAP.
+ - `--batch_size` (default: 50): Batch size for SHAP.
+ - `--explainer_types` (default: `all`): Comma-separated list of explainers to use (`lime`, `shap`, `gradcam`). Use `all` to include all three explainers.
+ - `--output_folder` (optional): Folder to save explanation images.
+
+ **Example:**
+
+ ```sh
+ python classify_image_and_explain.py --image_path "./images/test_image.jpg" --model_path "./models/model.keras" --train_directory "./data/train" --num_samples 300 --num_features 100 --segmentation_alg "quickshift" --kernel_size 4 --max_dist 200 --ratio 0.2 --max_evals 400 --batch_size 50 --explainer_types "lime,gradcam" --output_folder "./explanations"
+ ```
+
+ ### Supported Base Models
+
+ The following base models are supported for training (see the sketch after the list for how each name is resolved):
+ - VGG16
+ - VGG19
+ - ResNet50
+ - ResNet101
+ - InceptionV3
+ - DenseNet121
+ - DenseNet201
+ - MobileNetV2
+ - Xception
+ - InceptionResNetV2
+ - NASNetLarge
+ - NASNetMobile
+ - EfficientNetB0
+ - EfficientNetB7
+
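+ Each name maps to the `tf.keras.applications` class of the same name, which `train.py` instantiates as a frozen, headless feature extractor. A minimal sketch of that pattern:
+
+ ```python
+ from tensorflow.keras.applications import VGG16
+
+ # Build the backbone without its ImageNet classification head
+ base_model = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
+ for layer in base_model.layers:
+     layer.trainable = False  # transfer learning: only the new top layers are trained
+ ```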
+
+ ### Running Scripts in a Python Script
+
+ You can also run these scripts programmatically using Python's `subprocess` module. Here is an example of how to do this for each script:
+
+ ```python
+ import subprocess
+
+ # Run data_loader.py
+ subprocess.run([
+     "python", "data_loader.py",
+     "--path", "./data",
+     "--target_folder", "./processed_data",
+     "--dim", "224",
+     "--batch_size", "32",
+     "--num_workers", "4",
+     "--augment_data"
+ ])
+
+ # Run train.py
+ subprocess.run([
+     "python", "train.py",
+     "--base_models", "VGG16", "ResNet50",
+     "--shape", "224", "224", "3",
+     "--data_path", "./data",
+     "--log_dir", "./logs",
+     "--model_dir", "./models",
+     "--epochs", "100",
+     "--optimizer", "adam",
+     "--learning_rate", "0.001",
+     "--batch_size", "32",
+     "--patience", "10"
+ ])
+
+ # Run test.py
+ subprocess.run([
+     "python", "test.py",
+     "--model_dir", "./models",
+     "--img_path", "./images/test_image.jpg",
+     "--train_dir", "./data/train",
+     "--log_dir", "./logs"
+ ])
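+
+ # Run predict.py (example added for completeness; mirrors the CLI usage above)
+ subprocess.run([
+     "python", "predict.py",
+     "--model_path", "./models/vgg16_model.keras",
+     "--img_path", "./images/test_image.jpg",
+     "--train_dir", "./data/train"
+ ])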
+
+ # Run classify_image_and_explain.py
+ subprocess.run([
+     "python", "classify_image_and_explain.py",
+     "--image_path", "./images/test_image.jpg",
+     "--model_path", "./models/model.h5",
+     "--train_directory", "./data/train",
+     "--num_samples", "300",
+     "--num_features", "100",
+     "--segmentation_alg", "quickshift",
+     "--kernel_size", "4",
+     "--max_dist", "200",
+     "--ratio", "0.2",
+     "--max_evals", "400",
+     "--batch_size", "50",
+     "--explainer_types", "lime,gradcam",
+     "--output_folder", "./explanations"
+ ])
+ ```
+
+ ## License
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
+
+ ## Citing This Work: Underwater SONAR Image Classifier with LIME-based XAI
+
+ If you use our SONAR classifier or the explainer in your research, please use the following BibTeX entry.
+
+ ```bibtex
+ @article{natarajan2024underwater,
+   title={Underwater SONAR Image Classification and Analysis using LIME-based Explainable Artificial Intelligence},
+   author={Natarajan, Purushothaman and Nambiar, Athira},
+   journal={arXiv preprint arXiv:2408.12837},
+   year={2024}
+ }
+ ```
balanced_data_loader-1.py ADDED
@@ -0,0 +1,216 @@
+ import tensorflow as tf
+ import os
+ import argparse
+ from sklearn.model_selection import StratifiedShuffleSplit
+ from tqdm import tqdm
+ import uuid
+ import random
+
+ # Parses command line arguments
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
+     parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
+     parser.add_argument('--dim', type=int, default=224, help='Required image dimension')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+     parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
+     parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
+     parser.add_argument('--balance', action='store_true', help='Balance the dataset')
+     parser.add_argument('--split_type', type=str, choices=['random', 'stratified'], default='random',
+                         help='Type of data split (random or stratified)')
+     return parser.parse_args()
+
+ # Process the input images
+ def process_image(file_path, image_size):
+     image = tf.io.read_file(file_path)
+     image = tf.image.decode_image(image, channels=3, dtype=tf.float32)
+     image = tf.image.resize(image, [image_size, image_size])
+     image = tf.clip_by_value(image, 0.0, 1.0)
+     return image
+
+ # Balances the images of a specific class by down-sampling or augmenting
+ def balance_class_images(image_paths, labels, target_count, image_size, label, label_to_index, output_folder):
+     print(f"Balancing class '{label}'...")
+     label_idx = label_to_index.get(label, None)
+     if label_idx is None:
+         print(f"Label '{label}' not found in label_to_index.")
+         return [], []
+
+     image_paths = [img for img, lbl in zip(image_paths, labels) if lbl == label_idx]
+     num_images = len(image_paths)
+
+     print(f"Class '{label}' has {num_images} images before balancing.")
+
+     balanced_images = []
+     balanced_labels = []
+
+     original_count = num_images
+     synthetic_count = 0
+
+     if num_images > target_count:
+         balanced_images.extend(random.sample(image_paths, target_count))
+         balanced_labels.extend([label_idx] * target_count)
+         print(f"Removed {num_images - target_count} images from class '{label}'.")
+     else:
+         # Keep all originals; when num_images == target_count no augmentation is needed
+         balanced_images.extend(image_paths)
+         balanced_labels.extend([label_idx] * num_images)
+
+         num_to_add = target_count - num_images
+         if num_to_add > 0:
+             print(f"Class '{label}' needs {num_to_add} additional images for balancing.")
+
+         while num_to_add > 0:
+             img_path = random.choice(image_paths)
+             image = process_image(img_path, image_size)
+
+             for _ in range(min(num_to_add, 5)):  # Use up to 5 augmentations per image
+                 augmented_image = augment_image(image)
+                 balanced_images.append(augmented_image)
+                 balanced_labels.append(label_idx)
+                 num_to_add -= 1
+                 synthetic_count += 1
+
+         print(f"Added {synthetic_count} augmented images to class '{label}'.")
+
+     print(f"Class '{label}' has {len(balanced_images)} images after balancing.")
+
+     class_folder = os.path.join(output_folder, str(label_idx))
+     if not os.path.exists(class_folder):
+         os.makedirs(class_folder)
+
+     for img in balanced_images:
+         file_name = f"{uuid.uuid4()}.png"
+         file_path = os.path.join(class_folder, file_name)
+         save_image(img, file_path, image_size)
+
+     print(f"Saved {len(balanced_images)} images for class '{label}' (Original: {original_count}, Synthetic: {synthetic_count}).")
+
+     return balanced_images, balanced_labels
+
+ # Saves an image (a tensor or a file path) to a PNG file
+ def save_image(image, file_path, image_size=224):
+     if isinstance(image, str):
+         # A file path was passed in: load and resize it first
+         image = process_image(image, image_size)
+     if isinstance(image, tf.Tensor):
+         image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
+         image = tf.image.encode_png(image)
+     else:
+         raise ValueError("Expected image to be a TensorFlow tensor, but got a different type.")
+
+     tf.io.write_file(file_path, image)
+
+ # Augments an image with random transformations
+ def augment_image(image):
+     # Apply random augmentations using TensorFlow functions
+     image = tf.image.random_flip_left_right(image)
+     image = tf.image.random_flip_up_down(image)
+     image = tf.image.random_brightness(image, max_delta=0.1)
+     image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
+     image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
+     image = tf.image.random_hue(image, max_delta=0.1)
+     return image
+
+ # Creates a list of data augmentation functions
+ def create_datagens():
+     return [augment_image]
+
+ # Balances the entire dataset by balancing each class
+ def balance_data(images, labels, target_count, image_size, unique_labels, label_to_index, output_folder):
+     print(f"Balancing data: Target count per class = {target_count}")
+
+     all_balanced_images = []
+     all_balanced_labels = []
+
+     for label in tqdm(unique_labels, desc="Balancing classes"):
+         balanced_images, balanced_labels = balance_class_images(
+             images, labels, target_count, image_size, label, label_to_index, output_folder
+         )
+         all_balanced_images.extend(balanced_images)
+         all_balanced_labels.extend(balanced_labels)
+
+     total_original_images = sum(1 for img in all_balanced_images if isinstance(img, str))
+     total_synthetic_images = len(all_balanced_images) - total_original_images
+
+     print(f"\nTotal saved images: {len(all_balanced_images)} (Original: {total_original_images}, Synthetic: {total_synthetic_images})")
+
+     return all_balanced_images, all_balanced_labels
+
+ # Augments an image using TensorFlow functions.
+ # Note: these two helpers expect a module-level `image_size` and are not used
+ # by the main flow below.
+ def tf_augment_image(file_path, label):
+     image = tf.image.resize(tf.image.decode_jpeg(tf.io.read_file(file_path)), [image_size, image_size])
+     image = tf.cast(image, tf.float32) / 255.0
+     augmented_image = augment_image(image)
+     return augmented_image, label
+
+ def map_fn(file_path, label):
+     image, label = tf.py_function(tf_augment_image, [file_path, label], [tf.float32, tf.int32])
+     image.set_shape([image_size, image_size, 3])
+     label.set_shape([])
+     return image, label
+
+ # Loads images, splits them into train, validation, and test sets, and saves the splits
+ def load_and_save_splits(path, image_size, batch_size, balance, datagens, target_folder, split_type):
+     # Note: `balance` and `datagens` are accepted for CLI compatibility but are not
+     # applied in this flow; balancing is implemented separately in balance_data()
+     all_images = []
+     labels = []
+
+     for class_folder in os.listdir(path):
+         class_path = os.path.join(path, class_folder)
+         if os.path.isdir(class_path):
+             for img_file in os.listdir(class_path):
+                 img_path = os.path.join(class_path, img_file)
+                 all_images.append(img_path)
+                 labels.append(class_folder)  # Use the folder name as the label
+
+     print(f"Loaded {len(all_images)} images across {len(set(labels))} classes.")
+     print(f"Labels found: {set(labels)}")  # Print unique labels
+
+     unique_labels = list(set(labels))
+     label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
+     encoded_labels = [label_to_index[label] for label in labels]
+
+     print(f"Label to index mapping: {label_to_index}")
+
+     if split_type == 'stratified':
+         sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
+         train_indices, test_indices = next(sss.split(all_images, encoded_labels))
+     else:  # random split
+         total_images = len(all_images)
+         indices = list(range(total_images))
+         random.shuffle(indices)
+         train_indices = indices[:int(0.8 * total_images)]
+         test_indices = indices[int(0.8 * total_images):]
+
+     train_files = [all_images[i] for i in train_indices]
+     train_labels = [encoded_labels[i] for i in train_indices]
+     test_files = [all_images[i] for i in test_indices]
+     test_labels = [encoded_labels[i] for i in test_indices]
+
+     # Split the held-out 20% evenly into validation and test sets
+     sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
+     val_indices, test_indices = next(sss_val.split(test_files, test_labels))
+
+     val_files = [test_files[i] for i in val_indices]
+     val_labels = [test_labels[i] for i in val_indices]
+     test_files = [test_files[i] for i in test_indices]
+     test_labels = [test_labels[i] for i in test_indices]
+
+     # Save splits
+     for split_name, file_list, labels_list in [("train", train_files, train_labels), ("val", val_files, val_labels), ("test", test_files, test_labels)]:
+         split_folder = os.path.join(target_folder, split_name)
+         os.makedirs(split_folder, exist_ok=True)
+         with open(os.path.join(split_folder, f"{split_name}_list.txt"), 'w') as file_list_file:
+             for img_path, label in zip(file_list, labels_list):
+                 label_folder = os.path.join(split_folder, str(label))
+                 if not os.path.exists(label_folder):
+                     os.makedirs(label_folder)
+                 file_list_file.write(f"{img_path}\n")
+                 save_image(img_path, os.path.join(label_folder, f"{uuid.uuid4()}.png"), image_size)
+
+     print(f"Saved splits: train: {len(train_files)}, val: {len(val_files)}, test: {len(test_files)}.")
+
+ # Main function to run the data loader
+ def main():
+     args = parse_arguments()
+     load_and_save_splits(args.path, args.dim, args.batch_size, args.balance, create_datagens(), args.target_folder, args.split_type)
+
+ if __name__ == "__main__":
+     main()
classify_image_and_explain.py ADDED
@@ -0,0 +1,256 @@
+ import os
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.preprocessing.image import img_to_array, array_to_img, load_img
+ from lime.lime_image import LimeImageExplainer
+ from lime.wrappers.scikit_image import SegmentationAlgorithm
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import argparse
+ import shap
+ import cv2
+ import pickle
+
+ image_counter = 0
+ temp_folder = "temp_data"
+ output_folder = "explanations"
+
+ # Load the model and extract relevant details
+ def load_model_details(model_path):
+     if model_path.endswith('.keras'):
+         print("Loading .keras format model...")
+         model = tf.keras.models.load_model(model_path, compile=False)
+     elif model_path.endswith('.h5'):
+         print("Loading .h5 format model...")
+         model = tf.keras.models.load_model(model_path, compile=False)
+     else:
+         print("Loading SavedModel using TFSMLayer...")
+         model = tf.keras.Sequential([
+             tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')
+         ])
+
+     # Note: for the TFSMLayer path, input_shape and the conv-layer scan may be
+     # unavailable; .keras/.h5 models are the expected inputs here
+     input_shape = model.input_shape[1:3]
+     last_conv_layer_name = None
+     for layer in reversed(model.layers):
+         if isinstance(layer, tf.keras.layers.Conv2D):
+             last_conv_layer_name = layer.name
+             break
+
+     print(f"Model loaded with input shape: {input_shape} and last conv layer: {last_conv_layer_name}")
+     return model, last_conv_layer_name, input_shape
+
+ # Load the label encoder based on the training directory
+ def load_label_encoder(train_directory):
+     labels = sorted(os.listdir(train_directory))
+     label_encoder = {i: label for i, label in enumerate(labels)}
+     print(f"Label encoder created: {label_encoder}")
+     return label_encoder
+
+ def load_and_preprocess_image(filename, image_size):
+     # Load and preprocess the image for model input
+     print(f"Loading and preprocessing image from: {filename}")
+     image = tf.io.read_file(filename)
+     image = tf.image.decode_image(image, channels=3)
+
+     if not tf.executing_eagerly():
+         image.set_shape([None, None, 3])
+
+     image = tf.image.resize(image, [image_size[0], image_size[1]])
+     image = image / 255.0
+     image.set_shape([image_size[0], image_size[1], 3])
+
+     return image
+
+ # Create a dataset from the training directory
+ def create_dataset(data_dir, labels, image_size, batch_size):
+     print(f"Creating dataset from directory: {data_dir}")
+     image_files = []
+     image_labels = []
+
+     for label in labels:
+         label_dir = os.path.join(data_dir, label)
+         for image_file in os.listdir(label_dir):
+             image_files.append(os.path.join(label_dir, image_file))
+             image_labels.append(label)
+
+     label_map = {label: idx for idx, label in enumerate(labels)}
+     image_labels = [label_map[label] for label in image_labels]
+
+     dataset = tf.data.Dataset.from_tensor_slices((image_files, image_labels))
+     dataset = dataset.map(lambda x, y: (load_and_preprocess_image(x, image_size), y))
+     dataset = dataset.shuffle(buffer_size=len(image_files))
+     dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
+
+     print("Dataset created and batched")
+     return dataset
+
+ # Save preprocessed data (images and labels) to a file
+ def save_preprocessed_data(X_train, y_train, file_path):
+     print(f"Saving preprocessed data to: {file_path}")
+     with open(file_path, 'wb') as file:
+         pickle.dump((X_train, y_train), file)
+
+ def load_preprocessed_data(file_path):
+     print(f"Loading preprocessed data from: {file_path}")
+     with open(file_path, 'rb') as file:
+         return pickle.load(file)
+
+ def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
+     # Generate a Grad-CAM heatmap for the given image and model
+     grad_model = tf.keras.models.Model(
+         inputs=model.inputs, outputs=[model.get_layer(last_conv_layer_name).output, model.output]
+     )
+     with tf.GradientTape() as tape:
+         last_conv_layer_output, preds = grad_model(img_array)
+         preds = tf.convert_to_tensor(preds)
+         if pred_index is None:
+             pred_index = int(tf.argmax(preds[0]))  # Default to the class with the highest probability
+         class_channel = preds[:, pred_index]
+
+     grads = tape.gradient(class_channel, last_conv_layer_output)
+     pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+     last_conv_layer_output = last_conv_layer_output[0]
+     heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
+     heatmap = tf.squeeze(heatmap)
+     heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
+     return heatmap.numpy()
+
+ def save_and_display_gradcam(array, heatmap, alpha=0.8):
+     # Save and display the Grad-CAM heatmap overlaid on the original image
+     print("Saving and displaying Grad-CAM result...")
+     heatmap = np.uint8(255 * heatmap)
+     jet = plt.cm.jet
+     jet_colors = jet(np.arange(256))[:, :3]
+     jet_heatmap = jet_colors[heatmap]
+     jet_heatmap = array_to_img(jet_heatmap)
+     jet_heatmap = jet_heatmap.resize((array.shape[1], array.shape[0]))
+     jet_heatmap = img_to_array(jet_heatmap)
+     superimposed_img = jet_heatmap * alpha + array
+     superimposed_img = array_to_img(superimposed_img)
+     return superimposed_img
+
+ def generate_splime_mask_top_n(img_array, model, explainer, top_n=1, num_features=100, num_samples=300,
+                                segmentation_alg='quickshift', kernel_size=4, max_dist=200, ratio=0.2):
+     # Generate a SP-LIME mask for the given image and model,
+     # using superpixel segmentation for SP-LIME.
+     # Note: kernel_size/max_dist/ratio are quickshift parameters; 'slic' expects
+     # different arguments (e.g. n_segments), so adjust this call for other algorithms.
+     segmentation_fn = SegmentationAlgorithm(segmentation_alg, kernel_size=kernel_size, max_dist=max_dist, ratio=ratio)
+
+     explanation_instance = explainer.explain_instance(
+         img_array, model.predict, top_labels=top_n, hide_color=0,
+         num_samples=num_samples, num_features=num_features, segmentation_fn=segmentation_fn
+     )
+     explanation_mask = explanation_instance.get_image_and_mask(
+         explanation_instance.top_labels[0], positive_only=False,
+         num_features=num_features, hide_rest=True
+     )[1]
+
+     # Ensure the mask has the same shape as the input image
+     mask = np.zeros_like(img_array)  # Create a mask of the same shape as img_array
+     mask[explanation_mask == 1] = img_array[explanation_mask == 1]  # Overlay highlighted regions
+
+     # Set non-highlighted areas to white
+     mask = np.where(explanation_mask[:, :, np.newaxis] == 1, mask, 1.0)
+
+     return mask, explanation_instance
+
+ def explain_image_shap(img, model, class_names, top_prediction, max_evals=1000, batch_size=50):
+     # Generate SHAP explanations for the given image and model
+     masker = shap.maskers.Image("inpaint_telea", img[0].shape)
+
+     # Define a function to predict probabilities from the model
+     def f(X):
+         return model.predict(X)
+
+     # Create the SHAP explainer
+     explainer = shap.Explainer(f, masker, output_names=class_names)
+
+     # Get SHAP values for the top-1 output
+     shap_values = explainer(img, max_evals=max_evals, batch_size=batch_size, outputs=shap.Explanation.argsort.flip[:1])
+
+     return shap_values
+
+ def classify_image_and_explain(image_path, model_path, train_directory, num_samples, num_features, segmentation_alg, kernel_size, max_dist, ratio, max_evals, batch_size, explainer_types, output_folder):
+     # Main function to classify the image and generate explanations
+     global image_counter
+
+     if output_folder is None:
+         output_folder = "explanations"
+     if not os.path.exists(output_folder):
+         os.makedirs(output_folder)
+
+     model, last_conv_layer_name, input_shape = load_model_details(model_path)
+     label_encoder = load_label_encoder(train_directory)
+     labels = list(label_encoder.values())
+
+     # Load the image
+     image = load_img(image_path, target_size=input_shape)
+     if image.mode != 'RGB':
+         image = image.convert('RGB')
+     array = img_to_array(image)
+     img_array = array / 255.0
+     img_array = np.expand_dims(img_array, axis=0)
+
+     # Predict the class of the image
+     predictions = model.predict(img_array)
+     top_prediction = np.argmax(predictions[0])
+     top_label = label_encoder[top_prediction]
+
+     print(f"Prediction: {top_label} with probability {predictions[0][top_prediction]:.4f}")
+
+     # Generate explanations based on user-specified types
+     if 'gradcam' in explainer_types:
+         # Standard Grad-CAM recipe: drop the final softmax so gradients are taken
+         # w.r.t. the class logit. Note that this mutates the model in place, so any
+         # LIME/SHAP run after this sees logits rather than softmax probabilities.
+         model.layers[-1].activation = None
+         heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)
+         gradcam_image = save_and_display_gradcam(img_to_array(image), heatmap)
+         gradcam_image.save(os.path.join(output_folder, f"gradcam_{image_counter}.png"))
+
+     if 'lime' in explainer_types:
+         # SP-LIME explanation, using the user-specified segmentation settings
+         explainer = LimeImageExplainer()
+         splime_mask, explanation_instance = generate_splime_mask_top_n(
+             img_array[0], model, explainer, top_n=1, num_features=num_features, num_samples=num_samples,
+             segmentation_alg=segmentation_alg, kernel_size=kernel_size, max_dist=max_dist, ratio=ratio
+         )
+         # Ensure splime_mask is in [0, 1] range before saving
+         splime_mask = np.clip(splime_mask, 0, 1)
+         plt.imsave(os.path.join(output_folder, f"splime_{image_counter}.png"), splime_mask)
+
+     if 'shap' in explainer_types:
+         custom_image = img_to_array(image) / 255.0  # Preprocess image for SHAP
+         shap_values = explain_image_shap(custom_image.reshape(1, *custom_image.shape), model, labels, top_prediction, max_evals=max_evals, batch_size=batch_size)
+         shap.image_plot(shap_values[0], custom_image, labels=[top_label], show=False)
+         plt.savefig(os.path.join(output_folder, f"shap_{image_counter}.png"))
+         plt.close()
+
+     print("Image classification and explanation process completed.")
+     image_counter += 1
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Image classification and explanation script")
+     parser.add_argument("--image_path", type=str, required=True, help="Path to the input image")
+     parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
+     parser.add_argument("--train_directory", type=str, required=True, help="Directory containing training images")
+     parser.add_argument("--num_samples", type=int, default=300, help="Number of samples for LIME")
+     parser.add_argument("--num_features", type=int, default=100, help="Number of features for LIME")
+     parser.add_argument("--segmentation_alg", type=str, default='quickshift', help="Segmentation algorithm for LIME (options: quickshift, slic)")
+     parser.add_argument("--kernel_size", type=int, default=4, help="Kernel size for segmentation algorithm")
+     parser.add_argument("--max_dist", type=int, default=200, help="Max distance for segmentation algorithm")
+     parser.add_argument("--ratio", type=float, default=0.2, help="Ratio for segmentation algorithm")
+     parser.add_argument("--max_evals", type=int, default=400, help="Maximum evaluations for SHAP")
+     parser.add_argument("--batch_size", type=int, default=50, help="Batch size for SHAP")
+     parser.add_argument("--explainer_types", type=str, default='all', help="Comma-separated list of explainers to use (options: lime, shap, gradcam). Use 'all' to include all three.")
+     parser.add_argument("--output_folder", type=str, default=None, help="Output folder for explanations")
+
+     args = parser.parse_args()
+
+     # Strip whitespace so "lime, gradcam" and "lime,gradcam" both work
+     explainer_types = [t.strip() for t in args.explainer_types.split(',')] if args.explainer_types != 'all' else ['lime', 'shap', 'gradcam']
+
+     classify_image_and_explain(
+         args.image_path, args.model_path, args.train_directory, args.num_samples,
+         args.num_features, args.segmentation_alg, args.kernel_size, args.max_dist,
+         args.ratio, args.max_evals, args.batch_size, explainer_types, args.output_folder
+     )
data_loader.py ADDED
@@ -0,0 +1,173 @@
+ import tensorflow as tf
+ import os
+ import argparse
+ from sklearn.model_selection import StratifiedShuffleSplit
+ from tqdm import tqdm  # For progress display
+ import sys
+ import uuid  # For unique filename generation
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
+     parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
+     parser.add_argument('--dim', type=int, default=224, help='Required image dimension')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+     parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
+     parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
+     parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
+     return parser.parse_args()
+
+ def create_datagens():
+     # Create a list of ImageDataGenerator objects for different augmentations.
+     # Note: random_transform() does not apply `rescale`, so the first generator
+     # effectively yields an unmodified copy of the image.
+     return [
+         ImageDataGenerator(rescale=1./255),
+         ImageDataGenerator(rotation_range=20),
+         ImageDataGenerator(width_shift_range=0.2),
+         ImageDataGenerator(height_shift_range=0.2),
+         ImageDataGenerator(horizontal_flip=True)
+     ]
+
+ def process_image(file_path, image_size):
+     # Read, decode, resize, and normalize an image (runs inside tf.py_function)
+     file_path = file_path.numpy().decode('utf-8')
+     image = tf.io.read_file(file_path)
+     image = tf.image.decode_image(image, channels=3, dtype=tf.float32)
+     image = tf.image.resize(image, [image_size, image_size])
+     image = tf.clip_by_value(image, 0.0, 1.0)
+     return image
+
+ def save_image(image, file_path):
+     # Convert image to uint8, encode as JPEG, and save to file
+     image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
+     image = tf.image.encode_jpeg(image)
+     tf.io.write_file(file_path, image)
+
+ def load_data(path, image_size, batch_size):
+     # Load images and labels from the specified path
+     all_images = []
+     labels = []
+
+     for subdir, _, files in os.walk(path):
+         label = os.path.basename(subdir)
+         for fname in files:
+             if fname.endswith(('.jpg', '.jpeg', '.png')):
+                 all_images.append(os.path.join(subdir, fname))
+                 labels.append(label)
+
+     unique_labels = set(labels)
+     print(f"Found {len(all_images)} images in {path}\n")
+     print(f"Labels found ({len(unique_labels)}): {unique_labels}\n")
+
+     # Raise an error if no images are found
+     if len(all_images) == 0:
+         raise ValueError(f"No images found in the specified path: {path}")
+
+     # Stratified split of the dataset: 80% train, 20% held out
+     sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
+     train_indices, test_indices = next(sss.split(all_images, labels))
+
+     train_files = [all_images[i] for i in train_indices]
+     train_labels = [labels[i] for i in train_indices]
+     test_files = [all_images[i] for i in test_indices]
+     test_labels = [labels[i] for i in test_indices]
+
+     # Split the held-out 20% evenly into validation and test sets
+     sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
+     val_indices, test_indices = next(sss_val.split(test_files, test_labels))
+
+     val_files = [test_files[i] for i in val_indices]
+     val_labels = [test_labels[i] for i in val_indices]
+     test_files = [test_files[i] for i in test_indices]
+     test_labels = [test_labels[i] for i in test_indices]
+
+     print(f"Data split into {len(train_files)} train, {len(val_files)} validation, and {len(test_files)} test images.\n")
+
+     # Define a function to load images via tf.py_function
+     def tf_load_and_augment_image(file_path, label):
+         image = tf.py_function(func=lambda x: process_image(x, image_size), inp=[file_path], Tout=tf.float32)
+         image.set_shape([image_size, image_size, 3])
+         return image, label
+
+     # Create datasets from the loaded files and labels
+     train_dataset = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
+     val_dataset = tf.data.Dataset.from_tensor_slices((val_files, val_labels))
+     test_dataset = tf.data.Dataset.from_tensor_slices((test_files, test_labels))
+
+     train_dataset = train_dataset.map(lambda x, y: tf_load_and_augment_image(x, y))
+     val_dataset = val_dataset.map(lambda x, y: tf_load_and_augment_image(x, y))
+     test_dataset = test_dataset.map(lambda x, y: tf_load_and_augment_image(x, y))
+
+     train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
+     val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
+     test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
+
+     return train_dataset, val_dataset, test_dataset
+
+ def save_datasets_to_folders(dataset, folder_path, datagens=None):
+     # Save the dataset to the given folder, with optional augmentations
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path)
+
+     count = 0
+     for batch_images, batch_labels in tqdm(dataset, desc=f"Saving to {folder_path}"):
+         for i in range(batch_images.shape[0]):
+             image = batch_images[i].numpy()
+             label = batch_labels[i].numpy().decode('utf-8')
+             label_folder = os.path.join(folder_path, label)
+             if not os.path.exists(label_folder):
+                 os.makedirs(label_folder)
+
+             # Save the original image
+             file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
+             save_image(image, file_path)
+             count += 1
+
+             # Apply augmentations if datagens are provided
+             if datagens:
+                 for datagen in datagens:
+                     aug_image = datagen.random_transform(image)
+                     file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
+                     save_image(aug_image, file_path)
+                     count += 1
+
+     print(f"Saved {count} images to {folder_path}\n")
+     return count
+
+ def main():
+     # Main function to parse arguments, load data, and save datasets.
+     # Note: --num_workers is parsed for compatibility but not used by the tf.data pipeline.
+     args = parse_arguments()
+
+     if not os.path.exists(args.target_folder):
+         os.makedirs(args.target_folder)
+
+     train_folder = os.path.join(args.target_folder, 'train')
+     val_folder = os.path.join(args.target_folder, 'val')
+     test_folder = os.path.join(args.target_folder, 'test')
+
+     datagens = create_datagens() if args.augment_data else None
+
+     train_dataset, val_dataset, test_dataset = load_data(
+         args.path,
+         args.dim,
+         args.batch_size
+     )
+
+     # Save datasets to their respective folders and count images
+     train_count = save_datasets_to_folders(train_dataset, train_folder, datagens)
+     val_count = save_datasets_to_folders(val_dataset, val_folder)
+     test_count = save_datasets_to_folders(test_dataset, test_folder)
+
+     print(f"Train dataset saved to: {train_folder}\n")
+     print(f"Validation dataset saved to: {val_folder}\n")
+     print(f"Test dataset saved to: {test_folder}\n")
+
+     print('-'*20)
+
+     print(f"Number of images in training set: {train_count}\n")
+     print(f"Number of images in validation set: {val_count}\n")
+     print(f"Number of images in test set: {test_count}\n")
+
+ if __name__ == "__main__":
+     # Redirect stdout and stderr to avoid encoding issues
+     sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
+     sys.stderr = open(sys.stderr.fileno(), mode='w', encoding='utf-8', buffering=1)
+     main()
predict.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import argparse
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.preprocessing import image
+ from tensorflow.keras.models import load_model
+
+ def load_and_preprocess_image(img_path, target_size):
+     """Load and preprocess the image for prediction."""
+     img = image.load_img(img_path, target_size=target_size)
+     img_array = image.img_to_array(img)
+     img_array = np.expand_dims(img_array, axis=0)  # Create batch axis
+     img_array = img_array / 255.0  # Normalize the image
+     return img_array
+
+ def load_model_from_file(model_path):
+     """Load the pre-trained model from the specified path."""
+     model = load_model(model_path)
+     print(f"Model loaded from {model_path}")
+     return model
+
+ def make_predictions(model, img_array):
+     """Make predictions using the loaded model."""
+     predictions = model.predict(img_array)
+     return predictions
+
+ def get_class_names(train_dir):
+     """Get class names from the training directory."""
+     class_names = os.listdir(train_dir)  # Assuming subfolder names are the class labels
+     class_names.sort()  # Ensure consistent ordering
+     return class_names
+
+ def main(model_path, img_path, train_dir):
+     # Load the model, preprocess the image, make a prediction, and display the result
+     # Define target image size based on model requirements
+     target_size = (224, 224)  # Adjust if needed
+
+     # Load the model
+     model = load_model_from_file(model_path)
+
+     # Get class names from the train directory
+     class_names = get_class_names(train_dir)
+
+     # Load and preprocess the image
+     img_array = load_and_preprocess_image(img_path, target_size)
+
+     # Make predictions
+     predictions = make_predictions(model, img_array)
+     predicted_label_index = np.argmax(predictions, axis=1)[0]
+     predicted_label = class_names[predicted_label_index]
+     probability_score = predictions[0][predicted_label_index]
+
+     print(f"Predicted label: {predicted_label}, Probability: {probability_score:.4f}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Load a pre-trained model and make a prediction on a new image")
+     parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model')
+     parser.add_argument('--img_path', type=str, required=True, help='Path to the image to be predicted')
+     parser.add_argument('--train_dir', type=str, required=True, help='Directory containing training dataset for inferring class names')
+
+     args = parser.parse_args()
+     main(args.model_path, args.img_path, args.train_dir)
requirements.txt ADDED
Binary file (520 Bytes).
 
test.py ADDED
@@ -0,0 +1,161 @@
+ import os
+ import argparse
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.preprocessing import image
+ from tensorflow.keras.models import load_model
+ from sklearn.metrics import classification_report, confusion_matrix
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from tqdm import tqdm
+
+ def load_and_preprocess_image(img_path, target_size):
+     """Load and preprocess the image for prediction."""
+     img = image.load_img(img_path, target_size=target_size)
+     img_array = image.img_to_array(img)
+     img_array = np.expand_dims(img_array, axis=0)  # Create batch axis
+     img_array = img_array / 255.0  # Normalize the image
+     return img_array
+
+ def load_all_models(model_dir):
+     """Load all models from the specified directory."""
+     models = {}
+     for file_name in os.listdir(model_dir):
+         if file_name.endswith('_model.keras'):
+             model_path = os.path.join(model_dir, file_name)
+             model_name = file_name.split('_model.keras')[0]  # Extract model name
+             model = load_model(model_path)
+             models[model_name] = model
+             print(f"Model loaded from {model_path}")
+     if not models:
+         raise FileNotFoundError(f"No model files found in {model_dir}.")
+     return models
+
+ def load_model_from_file(model_path):
+     """Load a single model from the specified path."""
+     model = load_model(model_path)
+     print(f"Model loaded from {model_path}")
+     return model
+
+ def make_predictions(model, img_array):
+     """Make predictions using the loaded model."""
+     predictions = model.predict(img_array)
+     return predictions
+
+ def get_class_names(train_dir):
+     """Get class names from the training directory."""
+     class_names = os.listdir(train_dir)  # Assuming subfolder names are the class labels
+     class_names.sort()  # Ensure consistent ordering
+     return class_names
+
+ def compute_confusion_matrix_and_report(true_labels, predicted_labels, class_names, log_dir, model_name):
+     """Compute the confusion matrix and classification report, and save both to the log directory."""
+     # Compute confusion matrix
+     conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=class_names)
+     report = classification_report(true_labels, predicted_labels, target_names=class_names)
+
+     # Print the classification report
+     print(f"Model: {model_name}")
+     print(report)
+
+     # Plot the confusion matrix
+     plt.figure(figsize=(10, 8))
+     sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
+     plt.xlabel('Predicted Label')
+     plt.ylabel('True Label')
+     plt.title(f'Confusion Matrix - {model_name}')
+
+     # Save plot
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+
+     conf_matrix_plot_file = os.path.join(log_dir, f'confusion_matrix_{model_name}.png')
+     plt.savefig(conf_matrix_plot_file)
+     plt.close()
+
+     # Save results to the log directory
+     conf_matrix_file = os.path.join(log_dir, f'confusion_matrix_{model_name}.txt')
+     report_file = os.path.join(log_dir, f'classification_report_{model_name}.txt')
+
+     np.savetxt(conf_matrix_file, conf_matrix, fmt='%d', delimiter=',', header=','.join(class_names))
+     with open(report_file, 'w') as f:
+         f.write(report)
+
+     print(f"Confusion matrix and classification report saved to {log_dir} with model name: {model_name}")
+
+ # Main function to load models, make predictions, and evaluate performance
+ def main(model_path, model_dir, img_path, test_dir, train_dir, log_dir):
+     # Define target image size based on model requirements
+     target_size = (224, 224)  # Adjust if needed
+
+     if model_path:
+         # Load a single model
+         model = load_model_from_file(model_path)
+         models = {os.path.basename(model_path): model}
+     elif model_dir:
+         # Load all models from a directory
+         models = load_all_models(model_dir)
+     else:
+         raise ValueError("Either --model_path or --model_dir must be provided.")
+
+     # Get class names from the train directory
+     class_names = get_class_names(train_dir)
+     num_classes = len(class_names)
+
+     # If an image path is provided, perform prediction on that image
+     if img_path:
+         img_array = load_and_preprocess_image(img_path, target_size)
+         for model_name, model in models.items():
+             print(f"Model: {model_name}")
+             predictions = make_predictions(model, img_array)
+             predicted_label_index = np.argmax(predictions, axis=1)[0]
+             if predicted_label_index >= num_classes:
+                 raise ValueError(f"Predicted label index {predicted_label_index} is out of range for class names list.")
+             predicted_label = class_names[predicted_label_index]
+             probability_score = predictions[0][predicted_label_index]
+             print('-'*20)
+             print(f"Predicted label: {predicted_label}, Probability: {probability_score:.4f}")
+             print('-'*20)
+
+     # If a test directory is provided, perform batch predictions and evaluation
+     if test_dir:
+         files = [os.path.join(root, file) for root, _, files in os.walk(test_dir) for file in files if file.endswith(('png', 'jpg', 'jpeg'))]
+
+         for model_name, model in models.items():
+             true_labels = []
+             predicted_labels = []
+
+             for img_path in tqdm(files, desc=f"Processing images with {model_name}"):
+                 img_array = load_and_preprocess_image(img_path, target_size)
+                 predictions = make_predictions(model, img_array)
+                 predicted_label_index = np.argmax(predictions, axis=1)[0]
+                 if predicted_label_index >= num_classes:
+                     raise ValueError(f"Predicted label index {predicted_label_index} is out of range for class names list.")
+                 predicted_label = class_names[predicted_label_index]
+
+                 true_label = os.path.basename(os.path.dirname(img_path))  # Assuming folder name is the label
+                 if true_label not in class_names:
+                     raise ValueError(f"True label {true_label} is not in class names list.")
+
+                 true_labels.append(true_label)
+                 predicted_labels.append(predicted_label)
+
+             # Compute and save confusion matrix and classification report
+             compute_confusion_matrix_and_report(true_labels, predicted_labels, class_names, log_dir, model_name)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Load models and make predictions on new images or a test dataset")
+     parser.add_argument('--model_path', type=str, help='Path to a single saved model')
+     parser.add_argument('--model_dir', type=str, help='Directory containing saved models (loads all models in the folder)')
+     parser.add_argument('--img_path', type=str, help='Path to the image to be predicted')
+     parser.add_argument('--test_dir', type=str, help='Directory containing test dataset for batch predictions')
+     parser.add_argument('--train_dir', type=str, required=True, help='Directory containing training dataset for inferring class names')
+     parser.add_argument('--log_dir', type=str, required=True, help='Directory to save prediction results')
+
+     args = parser.parse_args()
+     main(args.model_path, args.model_dir, args.img_path, args.test_dir, args.train_dir, args.log_dir)
train.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ import argparse
+ import tensorflow as tf
+ from tensorflow.keras.models import Model
+ from tensorflow.keras.applications import (VGG16, VGG19, ResNet50, ResNet101, InceptionV3,
+                                            DenseNet121, DenseNet201, MobileNetV2, Xception, InceptionResNetV2,
+                                            NASNetLarge, NASNetMobile, EfficientNetB0, EfficientNetB7)
+ from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization
+ from tensorflow.keras.optimizers import Adam, SGD
+ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
+ import numpy as np
+
+ def load_and_preprocess_image(filename, label, image_size):
+     # Load image
+     image = tf.io.read_file(filename)
+     image = tf.image.decode_image(image, channels=3)
+
+     # Ensure the image tensor has a known shape when tracing
+     if not tf.executing_eagerly():
+         image.set_shape([None, None, 3])
+
+     # Resize image to the specified size (height and width from the shape tuple)
+     image = tf.image.resize(image, [image_size[0], image_size[1]])
+
+     # Normalize image to [0, 1]
+     image = image / 255.0
+     image.set_shape([image_size[0], image_size[1], 3])
+
+     return image, label
+
+ def create_dataset(data_dir, labels, image_size, batch_size):
+     image_files = []
+     image_labels = []
+
+     for label in labels:
+         label_dir = os.path.join(data_dir, label)
+         for image_file in os.listdir(label_dir):
+             image_files.append(os.path.join(label_dir, image_file))
+             image_labels.append(label)
+
+     # Create a mapping from labels to indices
+     label_map = {label: idx for idx, label in enumerate(labels)}
+     image_labels = [label_map[label] for label in image_labels]
+
+     # Convert to a TensorFlow dataset
+     dataset = tf.data.Dataset.from_tensor_slices((image_files, image_labels))
+     dataset = dataset.map(lambda x, y: load_and_preprocess_image(x, y, image_size))
+     dataset = dataset.shuffle(buffer_size=len(image_files))
+     dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
+
+     return dataset
+
+ def create_and_train_model(base_model, model_name, shape, X_train, X_val, num_classes, labels, log_dir, model_dir,
+                            epochs, optimizer_name, learning_rate, step_gamma, alpha, batch_size, patience):
+     # Freeze the base model layers
+     for layer in base_model.layers:
+         layer.trainable = False
+
+     # Add custom layers on top
+     x = base_model.output
+     x = Flatten()(x)
+     x = Dense(1024, activation='relu')(x)
+     x = Dropout(0.25)(x)
+
+     x = Dense(512, activation='relu')(x)
+     x = Dropout(0.25)(x)
+
+     x = Dense(256, activation='relu')(x)
+     x = BatchNormalization()(x)
+     x = Dropout(0.25)(x)
+
+     predictions = Dense(num_classes, activation='softmax')(x)  # One output per class
+     model = Model(inputs=base_model.input, outputs=predictions)
+
+     # Learning rate schedule
+     lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
+         initial_learning_rate=learning_rate,
+         decay_steps=1000,  # Adjust this according to your needs
+         decay_rate=step_gamma
+     )
+
+     # Select the optimizer
+     if optimizer_name.lower() == 'adam':
+         optimizer = Adam(learning_rate=lr_schedule)
+     elif optimizer_name.lower() == 'sgd':
+         optimizer = SGD(learning_rate=lr_schedule, momentum=alpha)
+     else:
+         raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+     # Compile the model
+     model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+
+     # Set up callbacks
+     checkpoint = ModelCheckpoint(os.path.join(model_dir, f'{model_name}_best_model.keras'),
+                                  monitor='val_accuracy', save_best_only=True, save_weights_only=False,
+                                  mode='max', verbose=1)
+     early_stopping = EarlyStopping(monitor='val_accuracy', patience=patience, verbose=1)
+
+     # Train the model. The datasets are already batched, so batch_size is not
+     # passed to fit() (Keras rejects batch_size for dataset inputs).
+     history = model.fit(X_train, epochs=epochs, validation_data=X_val,
+                         callbacks=[checkpoint, early_stopping])
+
+     # Save training logs
+     with open(os.path.join(log_dir, f'{model_name}_training.log'), 'w') as f:
+         num_epochs = len(history.history['loss'])  # Actual number of epochs completed
+         for epoch in range(num_epochs):
+             f.write(f"Epoch {epoch + 1}, "
+                     f"Train Loss: {history.history['loss'][epoch]:.4f}, "
+                     f"Train Accuracy: {history.history['accuracy'][epoch]:.4f}, "
+                     f"Val Loss: {history.history['val_loss'][epoch]:.4f}, "
+                     f"Val Accuracy: {history.history['val_accuracy'][epoch]:.4f}\n")
+
+     # Save labels in the model directory
+     with open(os.path.join(model_dir, 'labels.txt'), 'w') as f:
+         f.write('\n'.join(labels))
+
+     # Evaluate the model on the validation set
+     test_loss, test_accuracy = model.evaluate(X_val)
+     print(f'Validation Accuracy for {model_name}: {test_accuracy:.4f}')
+     print(f'Validation Loss for {model_name}: {test_loss:.4f}')
+
+     # Save the final trained model
+     model.save(os.path.join(model_dir, f'{model_name}_final_model.keras'))
+
+ def main(base_model_names, shape, data_path, log_dir, model_dir, epochs, optimizer, learning_rate, step_gamma, alpha, batch_size, patience):
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     # Extract labels from folder names
+     labels = sorted([d for d in os.listdir(os.path.join(data_path, 'train')) if os.path.isdir(os.path.join(data_path, 'train', d))])
+     num_classes = len(labels)
+
+     # Load data
+     X_train = create_dataset(os.path.join(data_path, 'train'), labels, shape, batch_size)
+     X_val = create_dataset(os.path.join(data_path, 'val'), labels, shape, batch_size)
+
+     if not base_model_names:
+         print("No base models specified. Exiting.")
+         return
+
+     # Resolve base model names to the imported Keras application classes,
+     # skipping unsupported names instead of raising a KeyError
+     base_models_dict = {}
+     for model_name in base_model_names:
+         model_cls = globals().get(model_name)
+         if model_cls is None:
+             print(f"Model {model_name} not supported.")
+             continue
+         base_models_dict[model_name] = model_cls(weights='imagenet', include_top=False, input_shape=shape)
+
+     for model_name, base_model in base_models_dict.items():
+         print(f'Training {model_name}...')
+         create_and_train_model(base_model, model_name, shape, X_train, X_val, num_classes, labels, log_dir, model_dir,
+                                epochs, optimizer, learning_rate, step_gamma, alpha, batch_size, patience)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Train models using transfer learning")
+     parser.add_argument('--base_models', type=str, nargs='+', default=[],
+                         help='List of base models to use for training. Leave empty to skip model training.')
+     parser.add_argument('--shape', type=int, nargs=3, default=(224, 224, 3), help='Input shape of the images')
+     parser.add_argument('--data_path', type=str, required=True, help='Path to the image data')
+     parser.add_argument('--log_dir', type=str, required=True, help='Directory to save logs')
+     parser.add_argument('--model_dir', type=str, required=True, help='Directory to save models')
+     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train')
+     parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer to use (adam or sgd)')
+     parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate')
+     parser.add_argument('--step_gamma', type=float, default=0.96, help='Decay rate for the exponential learning rate schedule')
+     parser.add_argument('--alpha', type=float, default=0.9, help='Momentum for the SGD optimizer')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
+     parser.add_argument('--patience', type=int, default=10, help='Patience for early stopping')
+
+     args = parser.parse_args()
+     main(args.base_models, tuple(args.shape), args.data_path, args.log_dir, args.model_dir,
+          args.epochs, args.optimizer, args.learning_rate, args.step_gamma, args.alpha, args.batch_size, args.patience)