import tensorflow as tf
import keras
import keras.layers as kl
class EnhancedHyenaPlusLayer(kl.Layer):
"""
Enhanced Hyena+DNA layer with multi-scale feature extraction, residual connections,
explicit dimension alignment, and layer normalization for improved gradient flow and stability.
"""
def __init__(self, filters, kernel_size, output_dim, use_residual=True, dilation_rate=1,
kernel_regularizer=None, **kwargs):
super(EnhancedHyenaPlusLayer, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = kernel_size
self.output_dim = output_dim
self.use_residual = use_residual
self.dilation_rate = dilation_rate
self.kernel_regularizer = kernel_regularizer
# Core convolution for long-range dependencies with mild regularization
self.conv = kl.Conv1D(filters, kernel_size, padding='same',
kernel_regularizer=kernel_regularizer)
# Multi-scale feature extraction with dilated convolutions
self.dilated_conv = kl.Conv1D(filters // 2, kernel_size,
padding='same',
dilation_rate=dilation_rate,
kernel_regularizer=kernel_regularizer)
# Parallel small kernel convolution for local features
self.local_conv = kl.Conv1D(filters // 2, 3, padding='same',
kernel_regularizer=kernel_regularizer)
# Batch normalization and activation
self.batch_norm = kl.BatchNormalization()
self.activation = kl.Activation('relu')
# Feature fusion layer
self.fusion = kl.Dense(filters, kernel_regularizer=kernel_regularizer)
# Explicit dimension alignment projection with regularization
self.projection = kl.Dense(output_dim, kernel_regularizer=kernel_regularizer)
# Layer normalization for stability
self.layer_norm = kl.LayerNormalization()
# Input projection for residual connection if dimensions don't match
self.input_projection = None
if use_residual:
self.input_projection = kl.Dense(output_dim, kernel_regularizer=kernel_regularizer)
def call(self, inputs, training=None):
# Save input for residual connection
residual = inputs
# Process through main convolution
x_main = self.conv(inputs)
# Process through dilated convolution for capturing long-range patterns
x_dilated = self.dilated_conv(inputs)
# Process through local convolution for capturing local patterns
x_local = self.local_conv(inputs)
# Concatenate multi-scale features
x_multi = tf.concat([x_dilated, x_local], axis=-1)
# Fuse features
x = self.fusion(x_multi) + x_main
x = self.batch_norm(x, training=training)
x = self.activation(x)
# Project to target dimension
x = self.projection(x)
# Add residual connection if enabled
if self.use_residual:
# Project input if needed for dimension matching
residual = self.input_projection(residual)
x = x + residual
# Apply layer normalization
x = self.layer_norm(x)
return x
def get_config(self):
config = super(EnhancedHyenaPlusLayer, self).get_config()
config.update({
'filters': self.filters,
'kernel_size': self.kernel_size,
'output_dim': self.output_dim,
'use_residual': self.use_residual,
'dilation_rate': self.dilation_rate,
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer) if self.kernel_regularizer else None
})
return config
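# Illustrative usage (shapes only, values are placeholders): the layer maps a
# (batch, seq_len, channels) tensor to (batch, seq_len, output_dim), e.g.
#   layer = EnhancedHyenaPlusLayer(filters=128, kernel_size=7, output_dim=96)
#   y = layer(tf.random.normal((2, 249, 4)))  # -> shape (2, 249, 96)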
class HybridContextAwareMSTA(kl.Layer):
"""
Hybrid Context-Aware Motif-Specific Transformer Attention (HCA-MSTA) module
with enhanced biological interpretability and selective motif attention.
Combines the strengths of previous approaches with improved positional encoding.
"""
def __init__(self, num_motifs, motif_dim, num_heads=4, dropout_rate=0.1,
kernel_regularizer=None, activity_regularizer=None, **kwargs):
super(HybridContextAwareMSTA, self).__init__(**kwargs)
self.num_motifs = num_motifs
self.motif_dim = motif_dim
self.num_heads = num_heads
self.dropout_rate = dropout_rate
self.kernel_regularizer = kernel_regularizer
self.activity_regularizer = activity_regularizer
# Motif embeddings with mild regularization
self.motif_embeddings = self.add_weight(
shape=(num_motifs, motif_dim),
initializer='glorot_uniform',
regularizer=activity_regularizer,
trainable=True,
name='motif_embeddings'
)
# Positional encoding for motifs
self.motif_position_encoding = self.add_weight(
shape=(num_motifs, motif_dim),
initializer='glorot_uniform',
trainable=True,
name='motif_position_encoding'
)
# Biological prior weights for motifs (importance weights)
self.motif_importance = self.add_weight(
shape=(num_motifs, 1),
initializer='ones',
regularizer=activity_regularizer,
trainable=True,
name='motif_importance'
)
# Attention mechanism components with regularization
self.query_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.key_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.value_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
# Multi-head attention
self.attention = kl.MultiHeadAttention(
num_heads=num_heads,
key_dim=motif_dim // num_heads,
dropout=dropout_rate
)
# Gating mechanism
self.gate_dense = kl.Dense(motif_dim, activation='sigmoid',
kernel_regularizer=kernel_regularizer)
# Output projection
self.output_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.dropout = kl.Dropout(dropout_rate)
self.layer_norm = kl.LayerNormalization()
# Feed-forward network for feature enhancement
self.ffn_dense1 = kl.Dense(motif_dim * 2, activation='relu',
kernel_regularizer=kernel_regularizer)
self.ffn_dense2 = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.ffn_layer_norm = kl.LayerNormalization()
self.ffn_dropout = kl.Dropout(dropout_rate)
def positional_masking(self, sequence_embeddings, motif_embeddings):
"""
Generate hybrid positional masking based on sequence and motif relevance
with improved biological context awareness and motif importance weighting.
Combines inverse distance and Gaussian approaches for better biological relevance.
"""
# Calculate similarity between sequence embeddings and motif embeddings
similarity = tf.matmul(sequence_embeddings, tf.transpose(motif_embeddings, [0, 2, 1]))
# Scale similarity scores for numerical stability
scaled_similarity = similarity / tf.sqrt(tf.cast(self.motif_dim, tf.float32))
# Apply softmax to get attention-like weights
attention_weights = tf.nn.softmax(scaled_similarity, axis=-1)
# Calculate position-aware weights with hybrid approach
seq_length = tf.shape(sequence_embeddings)[1]
motif_length = tf.shape(motif_embeddings)[1]
# Create position indices
position_indices = tf.range(seq_length)[:, tf.newaxis] - tf.range(motif_length)[tf.newaxis, :]
position_indices_float = tf.cast(position_indices, tf.float32)
# Inverse distance weighting (for local context)
inverse_weights = 1.0 / (1.0 + tf.abs(position_indices_float))
# Gaussian weighting (for smooth transitions)
gaussian_weights = tf.exp(-0.5 * tf.square(position_indices_float / 8.0)) # Gaussian with σ=8
# Combine both weighting schemes for a hybrid approach
# This captures both sharp local context and smooth transitions
position_weights = 0.5 * inverse_weights + 0.5 * gaussian_weights
position_weights = tf.expand_dims(position_weights, 0) # Add batch dimension
# Apply motif importance weighting with temperature scaling for sharper focus
motif_weights = tf.nn.softmax(self.motif_importance * 1.5, axis=0) # Temperature scaling
motif_weights = tf.expand_dims(tf.expand_dims(motif_weights, 0), 1) # [1, 1, num_motifs, 1]
# Combine attention weights with position weights and motif importance
combined_weights = attention_weights * position_weights * tf.squeeze(motif_weights, -1)
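        # combined_weights has shape (batch, seq_len, num_motifs); larger values mark
        # motif/position pairs that the attention step should emphasize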
return combined_weights
def call(self, inputs, training=None):
# Add positional encoding to motif embeddings
batch_size = tf.shape(inputs)[0]
# Expand motif embeddings and position encodings to batch dimension
motifs = tf.tile(tf.expand_dims(self.motif_embeddings, 0), [batch_size, 1, 1])
pos_encoding = tf.tile(tf.expand_dims(self.motif_position_encoding, 0), [batch_size, 1, 1])
# Add positional encoding to motifs
motifs_with_pos = motifs + pos_encoding
# Prepare query from input sequence embeddings
query = self.query_dense(inputs)
# Prepare key and value from motifs with positional encoding
key = self.key_dense(motifs_with_pos)
value = self.value_dense(motifs_with_pos)
# Generate positional masking
pos_mask = self.positional_masking(query, motifs_with_pos)
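        # The soft positional weights are supplied as the attention mask so that
        # distant or low-importance motif positions are down-weighted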
# Apply attention with positional masking
attention_output = self.attention(
query=query,
key=key,
value=value,
attention_mask=pos_mask,
training=training
)
# Apply gating mechanism to selectively focus on relevant features
gate = self.gate_dense(inputs)
gated_attention = gate * attention_output
# Process through output projection with residual connection
output = self.output_dense(gated_attention)
output = self.dropout(output, training=training)
output = self.layer_norm(output + inputs) # Residual connection
# Apply feed-forward network with residual connection
ffn_output = self.ffn_dense1(output)
ffn_output = self.ffn_dense2(ffn_output)
ffn_output = self.ffn_dropout(ffn_output, training=training)
final_output = self.ffn_layer_norm(output + ffn_output) # Residual connection
return final_output
def get_config(self):
config = super(HybridContextAwareMSTA, self).get_config()
config.update({
'num_motifs': self.num_motifs,
'motif_dim': self.motif_dim,
'num_heads': self.num_heads,
'dropout_rate': self.dropout_rate,
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer) if self.kernel_regularizer else None,
            'activity_regularizer': keras.regularizers.serialize(self.activity_regularizer) if self.activity_regularizer else None
})
return config
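# Illustrative usage (shapes only, values are placeholders): with inputs of shape
# (batch, seq_len, motif_dim), the module returns a tensor of the same shape, e.g.
#   msta = HybridContextAwareMSTA(num_motifs=48, motif_dim=96, num_heads=4)
#   y = msta(tf.random.normal((2, 62, 96)))  # -> shape (2, 62, 96)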
def HyenaMSTAPlus(params):
"""
Enhanced HyenaMSTA+ model for enhancer activity prediction with multi-scale feature
extraction, hybrid attention mechanism, and improved biological context modeling.
"""
    if params['encode'] == 'one-hot':
        input_layer = kl.Input(shape=(249, 4))
    elif params['encode'] == 'k-mer':
        input_layer = kl.Input(shape=(1, 64))
    else:
        raise ValueError(f"Unsupported encoding {params['encode']!r}; expected 'one-hot' or 'k-mer'.")
# Regularization settings - milder than previous run
l2_reg = params.get('l2_reg', 1e-6)
    kernel_regularizer = keras.regularizers.l2(l2_reg)
    activity_regularizer = keras.regularizers.l1(l2_reg / 20)
# Hyena+DNA processing
x = input_layer
hyena_layers = []
# Number of motifs and embedding dimension - optimized based on previous runs
num_motifs = params.get('num_motifs', 48) # Adjusted to optimal value from Run 2
motif_dim = params.get('motif_dim', 96) # Adjusted to optimal value from Run 2
# Apply Enhanced Hyena+DNA layers with increasing dilation rates
for i in range(params['convolution_layers']['n_layers']):
# Use increasing dilation rates for broader receptive field
dilation_rate = 2**min(i, 2) # 1, 2, 4 (capped at 4 to avoid excessive sparsity)
hyena_layer = EnhancedHyenaPlusLayer(
filters=params['convolution_layers']['filters'][i],
kernel_size=params['convolution_layers']['kernel_sizes'][i],
output_dim=motif_dim,
dilation_rate=dilation_rate,
kernel_regularizer=kernel_regularizer,
name=f'EnhancedHyenaPlus_{i+1}'
)
x = hyena_layer(x)
hyena_layers.append(x)
if params['encode'] == 'one-hot':
x = kl.MaxPooling1D(2)(x)
if params['dropout_conv'] == 'yes':
x = kl.Dropout(params['dropout_prob'])(x)
# Hybrid Context-Aware MSTA processing
ca_msta = HybridContextAwareMSTA(
num_motifs=num_motifs,
motif_dim=motif_dim,
num_heads=params.get('ca_msta_heads', 8),
dropout_rate=params['dropout_prob'],
kernel_regularizer=kernel_regularizer,
activity_regularizer=activity_regularizer
)
x = ca_msta(x)
# Flatten and dense layers
x = kl.Flatten()(x)
# Fully connected layers
for i in range(params['n_dense_layer']):
        x = kl.Dense(params['dense_neurons' + str(i + 1)],
                     name=f'Dense_{i + 1}')(x)
x = kl.BatchNormalization()(x)
x = kl.Activation('relu')(x)
x = kl.Dropout(params['dropout_prob'])(x)
# Main model bottleneck
bottleneck = x
# Heads per task (developmental and housekeeping enhancer activities)
tasks = ['Dev', 'Hk']
outputs = []
for task in tasks:
        outputs.append(kl.Dense(1, activation='linear', name=f'Dense_{task}')(bottleneck))
# Build Keras model
model = keras.models.Model([input_layer], outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=params['lr']),
        loss=['mse', 'mse'],
        loss_weights=[1, 1]
    )
return model, params
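# Minimal smoke test, illustrative only: this params dict is an assumption about the
# hyperparameter format expected by HyenaMSTAPlus, inferred from the code above; the
# values are placeholders, not the project's tuned hyperparameters.
if __name__ == '__main__':
    example_params = {
        'encode': 'one-hot',
        'convolution_layers': {
            'n_layers': 2,
            'filters': [128, 128],
            'kernel_sizes': [7, 7]
        },
        'dropout_conv': 'yes',
        'dropout_prob': 0.2,
        'n_dense_layer': 2,
        'dense_neurons1': 256,
        'dense_neurons2': 256,
        'lr': 1e-3,
        'num_motifs': 48,
        'motif_dim': 96,
        'ca_msta_heads': 8,
        'l2_reg': 1e-6
    }
    model, _ = HyenaMSTAPlus(example_params)
    model.summary()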