"""Fine-tune a BERT sequence classifier on SST-2 with early stopping and AUC tracking."""

import argparse
import json
import logging
import os
import pathlib
import random
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score
from torch import nn
# transformers' bundled AdamW is deprecated; use the torch implementation.
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
|
@dataclass
class TrainingConfig:
    """Hyperparameters and paths for a single fine-tuning run."""

    max_seq_len: int = 50
    epochs: int = 3
    batch_size: int = 32
    learning_rate: float = 2e-5
    patience: int = 1
    max_grad_norm: float = 10.0
    warmup_ratio: float = 0.1
    model_path: str = '/cpfs01/shared/MA4Tool/hug_ckpts/BERT_ckpt'
    num_labels: int = 2
    if_save_model: bool = True
    out_dir: str = './run_0'

    def validate(self) -> None:
        if self.max_seq_len <= 0:
            raise ValueError("max_seq_len must be positive")
        if self.epochs <= 0:
            raise ValueError("epochs must be positive")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not 0.0 < self.learning_rate < 1.0:
            raise ValueError("learning_rate must be between 0 and 1")
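# A minimal usage sketch (hypothetical values): override a few defaults and
# validate before handing the config to BertTrainer.
#
#     config = TrainingConfig(epochs=5, batch_size=16, out_dir='./run_1')
#     config.validate()  # raises ValueError on bad hyperparameters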
|
class DataProcessForSentence(Dataset):
    """Tokenizes, truncates, and pads single sentences for BERT classification."""

    def __init__(self, bert_tokenizer: AutoTokenizer, df: pd.DataFrame, max_seq_len: int = 50):
        self.bert_tokenizer = bert_tokenizer
        self.max_seq_len = max_seq_len
        self.input_ids, self.attention_mask, self.token_type_ids, self.labels = self._get_input(df)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            self.input_ids[idx],
            self.attention_mask[idx],
            self.token_type_ids[idx],
            self.labels[idx]
        )

    def _get_input(self, df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        sentences = df['s1'].values
        labels = df['similarity'].values

        tokens_seq = list(map(self.bert_tokenizer.tokenize, sentences))
        result = list(map(self._truncate_and_pad, tokens_seq))

        input_ids = torch.tensor([i[0] for i in result], dtype=torch.long)
        attention_mask = torch.tensor([i[1] for i in result], dtype=torch.long)
        token_type_ids = torch.tensor([i[2] for i in result], dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)

        return input_ids, attention_mask, token_type_ids, labels

    def _truncate_and_pad(self, tokens_seq: List[str]) -> Tuple[List[int], List[int], List[int]]:
        # Reserve two positions for the [CLS] and [SEP] special tokens expected
        # by BERT's single-sentence input convention.
        tokens_seq = ['[CLS]'] + tokens_seq[:self.max_seq_len - 2] + ['[SEP]']
        padding_length = self.max_seq_len - len(tokens_seq)

        input_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens_seq)
        input_ids += [self.bert_tokenizer.pad_token_id] * padding_length
        attention_mask = [1] * len(tokens_seq) + [0] * padding_length
        token_type_ids = [0] * self.max_seq_len

        return input_ids, attention_mask, token_type_ids
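# Shape sketch: with max_seq_len = 50 and a DataFrame of N rows, _get_input
# yields (N, 50) input_ids / attention_mask / token_type_ids tensors and an
# (N,) labels tensor; each row encodes "[CLS] tokens ... [SEP]" plus padding.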
|
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification plus its tokenizer."""

    def __init__(self, model_path: str, num_labels: int, requires_grad: bool = True):
        super().__init__()
        try:
            self.bert = BertForSequenceClassification.from_pretrained(
                model_path,
                num_labels=num_labels
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception as e:
            logger.error(f"Failed to load BERT model: {e}")
            raise

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Freezing is opt-out: requires_grad=False turns this into a feature extractor.
        for param in self.bert.parameters():
            param.requires_grad = requires_grad

    def forward(
        self,
        batch_seqs: torch.Tensor,
        batch_seq_masks: torch.Tensor,
        batch_seq_segments: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        outputs = self.bert(
            input_ids=batch_seqs,
            attention_mask=batch_seq_masks,
            token_type_ids=batch_seq_segments,
            labels=labels
        )
        loss, logits = outputs.loss, outputs.logits
        probabilities = nn.functional.softmax(logits, dim=-1)
        return loss, logits, probabilities
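# Forward-pass sketch: for batch size B and sequence length L, the model maps
# (B, L) id/mask/segment tensors plus (B,) labels to a scalar loss,
# (B, num_labels) logits, and softmax probabilities over the label axis.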
|
class BertTrainer:
    """Orchestrates data preparation, optimization, validation, and checkpointing."""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.config.validate()
        self.model = BertClassifier(config.model_path, config.num_labels)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
|
    def _prepare_data(
        self,
        train_df: pd.DataFrame,
        dev_df: pd.DataFrame,
        test_df: pd.DataFrame
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        # Only the training loader shuffles; dev/test keep row order so saved
        # predictions line up with the input files.
        loaders = []
        for df, shuffle in ((train_df, True), (dev_df, False), (test_df, False)):
            data = DataProcessForSentence(
                self.model.tokenizer,
                df,
                max_seq_len=self.config.max_seq_len
            )
            loaders.append(DataLoader(
                data,
                shuffle=shuffle,
                batch_size=self.config.batch_size
            ))

        train_loader, dev_loader, test_loader = loaders
        return train_loader, dev_loader, test_loader
|
    def _prepare_optimizer(self, num_training_steps: int) -> Tuple[AdamW, Any]:
        # Standard BERT fine-tuning recipe: no weight decay on biases or
        # LayerNorm parameters.
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.config.learning_rate
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(num_training_steps * self.config.warmup_ratio),
            num_training_steps=num_training_steps
        )

        return optimizer, scheduler
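    # Warmup arithmetic: with 100 batches per epoch and epochs = 3,
    # num_training_steps = 300, so warmup_ratio = 0.1 gives 30 warmup steps;
    # the learning rate ramps linearly up over those steps, then decays
    # linearly to zero at step 300.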
|
    def _initialize_training_stats(self) -> Dict[str, List]:
        return {
            'epochs_count': [],
            'train_losses': [],
            'train_accuracies': [],
            'valid_losses': [],
            'valid_accuracies': [],
            'valid_aucs': []
        }
|
    def _update_training_stats(
        self,
        training_stats: Dict[str, List],
        epoch: int,
        train_metrics: Dict[str, float],
        val_metrics: Dict[str, float]
    ) -> None:
        training_stats['epochs_count'].append(epoch)
        training_stats['train_losses'].append(train_metrics['loss'])
        training_stats['train_accuracies'].append(train_metrics['accuracy'])
        training_stats['valid_losses'].append(val_metrics['loss'])
        training_stats['valid_accuracies'].append(val_metrics['accuracy'])
        training_stats['valid_aucs'].append(val_metrics['auc'])

        logger.info(
            f"Training - Loss: {train_metrics['loss']:.4f}, "
            f"Accuracy: {train_metrics['accuracy'] * 100:.2f}%"
        )
        logger.info(
            f"Validation - Loss: {val_metrics['loss']:.4f}, "
            f"Accuracy: {val_metrics['accuracy'] * 100:.2f}%, "
            f"AUC: {val_metrics['auc']:.4f}"
        )
|
    def _save_checkpoint(
        self,
        target_dir: str,
        epoch: int,
        optimizer: AdamW,
        best_score: float,
        training_stats: Dict[str, List]
    ) -> None:
        checkpoint = {
            "epoch": epoch,
            "model": self.model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_score": best_score,
            **training_stats
        }
        torch.save(
            checkpoint,
            os.path.join(target_dir, "best.pth.tar")
        )
        logger.info("Model saved successfully")
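    # Resume sketch: the checkpoint bundles weights, optimizer state, and the
    # accumulated training_stats, so train_and_evaluate(checkpoint=...) can
    # restore scoring history along with the model.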
|
    def _load_checkpoint(
        self,
        checkpoint_path: str,
        optimizer: AdamW,
        training_stats: Dict[str, List]
    ) -> float:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        for key in training_stats:
            training_stats[key] = checkpoint[key]
        logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        return checkpoint["best_score"]
|
    def _train_epoch(
        self,
        train_loader: DataLoader,
        optimizer: AdamW,
        scheduler: Any
    ) -> Dict[str, float]:
        self.model.train()
        total_loss = 0
        correct_preds = 0

        for batch in tqdm(train_loader, desc="Training"):
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, attention_mask, token_type_ids, labels = batch

            optimizer.zero_grad()
            loss, _, probabilities = self.model(input_ids, attention_mask, token_type_ids, labels)

            loss.backward()
            # Clip gradients before the optimizer step to bound the update size.
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            correct_preds += (probabilities.argmax(dim=1) == labels).sum().item()

        return {
            'loss': total_loss / len(train_loader),
            'accuracy': correct_preds / len(train_loader.dataset)
        }
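    # Clipping note: max_grad_norm = 10.0 is permissive; 1.0 is the more common
    # default for BERT fine-tuning, so treat this value as a tunable.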
|
    def _validate_epoch(self, dev_loader: DataLoader) -> Tuple[Dict[str, float], List[float]]:
        self.model.eval()
        total_loss = 0
        correct_preds = 0
        all_probs = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(dev_loader, desc="Validating"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, attention_mask, token_type_ids, labels = batch

                loss, _, probabilities = self.model(input_ids, attention_mask, token_type_ids, labels)

                total_loss += loss.item()
                correct_preds += (probabilities.argmax(dim=1) == labels).sum().item()
                # Keep the positive-class probability for AUC computation.
                all_probs.extend(probabilities[:, 1].cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        metrics = {
            'loss': total_loss / len(dev_loader),
            'accuracy': correct_preds / len(dev_loader.dataset),
            'auc': roc_auc_score(all_labels, all_probs)
        }

        return metrics, all_probs
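    # AUC caveat: roc_auc_score needs both classes present in the split; a
    # single-class dev or test set raises a ValueError here.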
|
    def _evaluate_test_set(
        self,
        test_loader: DataLoader,
        target_dir: str,
        epoch: int
    ) -> None:
        test_metrics, all_probs = self._validate_epoch(test_loader)
        logger.info(f"Test accuracy: {test_metrics['accuracy'] * 100:.2f}%")

        test_prediction = pd.DataFrame({'prob_1': all_probs})
        test_prediction['prob_0'] = 1 - test_prediction['prob_1']
        # Vectorized argmax over the two columns; equivalent to thresholding
        # prob_1 at 0.5.
        test_prediction['prediction'] = (
            test_prediction['prob_1'] >= test_prediction['prob_0']
        ).astype(int)

        output_path = os.path.join(target_dir, f"test_prediction_epoch_{epoch}.csv")
        test_prediction.to_csv(output_path, index=False)
        logger.info(f"Test predictions saved to {output_path}")
|
    def train_and_evaluate(
        self,
        train_df: pd.DataFrame,
        dev_df: pd.DataFrame,
        test_df: pd.DataFrame,
        target_dir: str,
        checkpoint: Optional[str] = None
    ) -> None:
        try:
            os.makedirs(target_dir, exist_ok=True)

            train_loader, dev_loader, test_loader = self._prepare_data(
                train_df, dev_df, test_df
            )

            optimizer, scheduler = self._prepare_optimizer(
                len(train_loader) * self.config.epochs
            )

            training_stats = self._initialize_training_stats()
            best_score = 0.0
            patience_counter = 0

            if checkpoint:
                best_score = self._load_checkpoint(checkpoint, optimizer, training_stats)

            for epoch in range(1, self.config.epochs + 1):
                logger.info(f"Training epoch {epoch}")

                train_metrics = self._train_epoch(train_loader, optimizer, scheduler)
                val_metrics, _ = self._validate_epoch(dev_loader)

                self._update_training_stats(training_stats, epoch, train_metrics, val_metrics)

                # Early stopping on dev accuracy: checkpoint and run the test
                # set on improvement, otherwise spend one unit of patience.
                if val_metrics['accuracy'] > best_score:
                    best_score = val_metrics['accuracy']
                    patience_counter = 0
                    if self.config.if_save_model:
                        self._save_checkpoint(
                            target_dir,
                            epoch,
                            optimizer,
                            best_score,
                            training_stats
                        )
                    self._evaluate_test_set(test_loader, target_dir, epoch)
                else:
                    patience_counter += 1
                    if patience_counter >= self.config.patience:
                        logger.info("Early stopping triggered")
                        break

            final_infos = {
                "sentiment": {
                    "means": {
                        "best_acc": best_score
                    }
                }
            }

            with open(os.path.join(self.config.out_dir, "final_info.json"), "w") as f:
                json.dump(final_infos, f)

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
|
def set_seed(seed: int = 42) -> None:
    """Seed every RNG this script touches for (mostly) reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Note: setting PYTHONHASHSEED here only affects child processes; the
    # current interpreter's hash randomization was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
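# Determinism caveat: cudnn.deterministic trades speed for repeatable kernels,
# and some CUDA ops stay nondeterministic unless
# torch.use_deterministic_algorithms(True) is also enabled.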
|
def main(out_dir):
    try:
        config = TrainingConfig(out_dir=out_dir)
        pathlib.Path(config.out_dir).mkdir(parents=True, exist_ok=True)

        # Seed before any model or data work so initialization is reproducible.
        set_seed(2024)

        data_path = "/cpfs01/shared/MA4Tool/datasets/SST-2/"
        train_df = pd.read_csv(
            os.path.join(data_path, "train.tsv"),
            sep='\t',
            header=None,
            names=['similarity', 's1']
        )
        dev_df = pd.read_csv(
            os.path.join(data_path, "dev.tsv"),
            sep='\t',
            header=None,
            names=['similarity', 's1']
        )
        test_df = pd.read_csv(
            os.path.join(data_path, "test.tsv"),
            sep='\t',
            header=None,
            names=['similarity', 's1']
        )

        trainer = BertTrainer(config)
        # Keep checkpoints and predictions alongside final_info.json in out_dir.
        trainer.train_and_evaluate(train_df, dev_df, test_df, config.out_dir)

    except Exception as e:
        logger.error(f"Program failed: {e}")
        raise
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="run_0")
    args = parser.parse_args()
    try:
        main(args.out_dir)
    except Exception as e:
        print("Original error in subprocess:", flush=True)
        # Ensure the directory exists and close the handle after writing.
        os.makedirs(args.out_dir, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as f:
            traceback.print_exc(file=f)
        raise