import tensorflow as tf
import keras
import keras.layers as kl
class EnhancedHyenaPlusLayer(kl.Layer):
"""
Enhanced Hyena+DNA layer with multi-scale feature extraction, residual connections,
explicit dimension alignment, and layer normalization for improved gradient flow and stability.
"""
def __init__(self, filters, kernel_size, output_dim, use_residual=True, dilation_rate=1,
kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
self.filters = filters
self.kernel_size = kernel_size
self.output_dim = output_dim
self.use_residual = use_residual
self.dilation_rate = dilation_rate
self.kernel_regularizer = kernel_regularizer
# Core convolution for long-range dependencies with mild regularization
self.conv = kl.Conv1D(filters, kernel_size, padding='same',
kernel_regularizer=kernel_regularizer)
# Multi-scale feature extraction with dilated convolutions
self.dilated_conv = kl.Conv1D(filters // 2, kernel_size,
padding='same',
dilation_rate=dilation_rate,
kernel_regularizer=kernel_regularizer)
# Parallel small kernel convolution for local features
self.local_conv = kl.Conv1D(filters // 2, 3, padding='same',
kernel_regularizer=kernel_regularizer)
# Batch normalization and activation
self.batch_norm = kl.BatchNormalization()
self.activation = kl.Activation('relu')
# Feature fusion layer
self.fusion = kl.Dense(filters, kernel_regularizer=kernel_regularizer)
# Explicit dimension alignment projection with regularization
self.projection = kl.Dense(output_dim, kernel_regularizer=kernel_regularizer)
# Layer normalization for stability
self.layer_norm = kl.LayerNormalization()
        # Input projection for the residual branch; applied whenever residuals
        # are enabled so the skip path always matches output_dim
        self.input_projection = None
        if use_residual:
            self.input_projection = kl.Dense(output_dim, kernel_regularizer=kernel_regularizer)
def call(self, inputs, training=None):
# Save input for residual connection
residual = inputs
# Process through main convolution
x_main = self.conv(inputs)
# Process through dilated convolution for capturing long-range patterns
x_dilated = self.dilated_conv(inputs)
# Process through local convolution for capturing local patterns
x_local = self.local_conv(inputs)
# Concatenate multi-scale features
x_multi = tf.concat([x_dilated, x_local], axis=-1)
# Fuse features
x = self.fusion(x_multi) + x_main
x = self.batch_norm(x, training=training)
x = self.activation(x)
# Project to target dimension
x = self.projection(x)
# Add residual connection if enabled
        if self.use_residual:
            # Project the input so the skip path matches the output dimension
            residual = self.input_projection(residual)
            x = x + residual
# Apply layer normalization
x = self.layer_norm(x)
return x
    def get_config(self):
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'output_dim': self.output_dim,
            'use_residual': self.use_residual,
            'dilation_rate': self.dilation_rate,
            # Serialize the regularizer so the config stays JSON-serializable
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
            if self.kernel_regularizer else None,
        })
        return config
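
# Illustrative usage sketch (the values here are assumptions, not settings from
# a training config): a one-hot DNA batch of shape [batch, 249, 4] is projected
# to `output_dim` channels while the sequence length is preserved.
#
#     layer = EnhancedHyenaPlusLayer(filters=128, kernel_size=7, output_dim=96)
#     y = layer(tf.zeros([2, 249, 4]))  # -> shape [2, 249, 96]
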
class HybridContextAwareMSTA(kl.Layer):
"""
Hybrid Context-Aware Motif-Specific Transformer Attention (HCA-MSTA) module
with enhanced biological interpretability and selective motif attention.
Combines the strengths of previous approaches with improved positional encoding.
"""
def __init__(self, num_motifs, motif_dim, num_heads=4, dropout_rate=0.1,
kernel_regularizer=None, activity_regularizer=None, **kwargs):
        super().__init__(**kwargs)
self.num_motifs = num_motifs
self.motif_dim = motif_dim
self.num_heads = num_heads
self.dropout_rate = dropout_rate
self.kernel_regularizer = kernel_regularizer
self.activity_regularizer = activity_regularizer
# Motif embeddings with mild regularization
self.motif_embeddings = self.add_weight(
shape=(num_motifs, motif_dim),
initializer='glorot_uniform',
regularizer=activity_regularizer,
trainable=True,
name='motif_embeddings'
)
# Positional encoding for motifs
self.motif_position_encoding = self.add_weight(
shape=(num_motifs, motif_dim),
initializer='glorot_uniform',
trainable=True,
name='motif_position_encoding'
)
# Biological prior weights for motifs (importance weights)
self.motif_importance = self.add_weight(
shape=(num_motifs, 1),
initializer='ones',
regularizer=activity_regularizer,
trainable=True,
name='motif_importance'
)
# Attention mechanism components with regularization
self.query_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.key_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.value_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
# Multi-head attention
self.attention = kl.MultiHeadAttention(
num_heads=num_heads,
key_dim=motif_dim // num_heads,
dropout=dropout_rate
)
# Gating mechanism
self.gate_dense = kl.Dense(motif_dim, activation='sigmoid',
kernel_regularizer=kernel_regularizer)
# Output projection
self.output_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.dropout = kl.Dropout(dropout_rate)
self.layer_norm = kl.LayerNormalization()
# Feed-forward network for feature enhancement
self.ffn_dense1 = kl.Dense(motif_dim * 2, activation='relu',
kernel_regularizer=kernel_regularizer)
self.ffn_dense2 = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
self.ffn_layer_norm = kl.LayerNormalization()
self.ffn_dropout = kl.Dropout(dropout_rate)
def positional_masking(self, sequence_embeddings, motif_embeddings):
"""
Generate hybrid positional masking based on sequence and motif relevance
with improved biological context awareness and motif importance weighting.
Combines inverse distance and Gaussian approaches for better biological relevance.
"""
# Calculate similarity between sequence embeddings and motif embeddings
similarity = tf.matmul(sequence_embeddings, tf.transpose(motif_embeddings, [0, 2, 1]))
# Scale similarity scores for numerical stability
scaled_similarity = similarity / tf.sqrt(tf.cast(self.motif_dim, tf.float32))
# Apply softmax to get attention-like weights
attention_weights = tf.nn.softmax(scaled_similarity, axis=-1)
# Calculate position-aware weights with hybrid approach
seq_length = tf.shape(sequence_embeddings)[1]
motif_length = tf.shape(motif_embeddings)[1]
# Create position indices
position_indices = tf.range(seq_length)[:, tf.newaxis] - tf.range(motif_length)[tf.newaxis, :]
position_indices_float = tf.cast(position_indices, tf.float32)
# Inverse distance weighting (for local context)
inverse_weights = 1.0 / (1.0 + tf.abs(position_indices_float))
# Gaussian weighting (for smooth transitions)
gaussian_weights = tf.exp(-0.5 * tf.square(position_indices_float / 8.0)) # Gaussian with σ=8
# Combine both weighting schemes for a hybrid approach
# This captures both sharp local context and smooth transitions
position_weights = 0.5 * inverse_weights + 0.5 * gaussian_weights
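        # Worked example (illustrative): at distance 0 both terms equal 1.0, so
        # the combined weight is 1.0; at distance 8 the inverse term is 1/9 ~= 0.111
        # and the Gaussian term is exp(-0.5) ~= 0.607, giving ~0.359.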
position_weights = tf.expand_dims(position_weights, 0) # Add batch dimension
# Apply motif importance weighting with temperature scaling for sharper focus
motif_weights = tf.nn.softmax(self.motif_importance * 1.5, axis=0) # Temperature scaling
motif_weights = tf.expand_dims(tf.expand_dims(motif_weights, 0), 1) # [1, 1, num_motifs, 1]
# Combine attention weights with position weights and motif importance
combined_weights = attention_weights * position_weights * tf.squeeze(motif_weights, -1)
return combined_weights
def call(self, inputs, training=None):
# Add positional encoding to motif embeddings
batch_size = tf.shape(inputs)[0]
# Expand motif embeddings and position encodings to batch dimension
motifs = tf.tile(tf.expand_dims(self.motif_embeddings, 0), [batch_size, 1, 1])
pos_encoding = tf.tile(tf.expand_dims(self.motif_position_encoding, 0), [batch_size, 1, 1])
# Add positional encoding to motifs
motifs_with_pos = motifs + pos_encoding
# Prepare query from input sequence embeddings
query = self.query_dense(inputs)
# Prepare key and value from motifs with positional encoding
key = self.key_dense(motifs_with_pos)
value = self.value_dense(motifs_with_pos)
# Generate positional masking
pos_mask = self.positional_masking(query, motifs_with_pos)
        # Apply attention with the soft positional mask. Note: Keras converts
        # mask values m into additive logit penalties (1 - m) * -1e9, so
        # positions with weights well below 1 are strongly down-weighted and
        # attention concentrates on the highest-weight motifs
        attention_output = self.attention(
            query=query,
            key=key,
            value=value,
            attention_mask=pos_mask,
            training=training
        )
# Apply gating mechanism to selectively focus on relevant features
gate = self.gate_dense(inputs)
gated_attention = gate * attention_output
# Process through output projection with residual connection
output = self.output_dense(gated_attention)
output = self.dropout(output, training=training)
output = self.layer_norm(output + inputs) # Residual connection
# Apply feed-forward network with residual connection
ffn_output = self.ffn_dense1(output)
ffn_output = self.ffn_dense2(ffn_output)
ffn_output = self.ffn_dropout(ffn_output, training=training)
final_output = self.ffn_layer_norm(output + ffn_output) # Residual connection
return final_output
    def get_config(self):
        config = super().get_config()
        config.update({
            'num_motifs': self.num_motifs,
            'motif_dim': self.motif_dim,
            'num_heads': self.num_heads,
            'dropout_rate': self.dropout_rate,
            # Serialize regularizers so the config stays JSON-serializable
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
            if self.kernel_regularizer else None,
            'activity_regularizer': keras.regularizers.serialize(self.activity_regularizer)
            if self.activity_regularizer else None,
        })
        return config
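
# Illustrative usage sketch (assumption: the sequence features already carry
# motif_dim channels, e.g. produced by EnhancedHyenaPlusLayer above; the module
# preserves input shape):
#
#     msta = HybridContextAwareMSTA(num_motifs=48, motif_dim=96, num_heads=4)
#     y = msta(tf.zeros([2, 124, 96]))  # -> shape [2, 124, 96]
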
def HyenaMSTAPlus(params):
"""
Enhanced HyenaMSTA+ model for enhancer activity prediction with multi-scale feature
extraction, hybrid attention mechanism, and improved biological context modeling.
"""
    if params['encode'] == 'one-hot':
        input_layer = kl.Input(shape=(249, 4))
    elif params['encode'] == 'k-mer':
        input_layer = kl.Input(shape=(1, 64))
    else:
        raise ValueError(f"Unsupported encoding {params['encode']!r}; expected 'one-hot' or 'k-mer'")
    # Regularization settings - milder than the previous run
    l2_reg = params.get('l2_reg', 1e-6)
    kernel_regularizer = tf.keras.regularizers.l2(l2_reg)
    activity_regularizer = tf.keras.regularizers.l1(l2_reg / 20)
# Hyena+DNA processing
x = input_layer
# Number of motifs and embedding dimension - optimized based on previous runs
num_motifs = params.get('num_motifs', 48) # Adjusted to optimal value from Run 2
motif_dim = params.get('motif_dim', 96) # Adjusted to optimal value from Run 2
# Apply Enhanced Hyena+DNA layers with increasing dilation rates
for i in range(params['convolution_layers']['n_layers']):
# Use increasing dilation rates for broader receptive field
dilation_rate = 2**min(i, 2) # 1, 2, 4 (capped at 4 to avoid excessive sparsity)
hyena_layer = EnhancedHyenaPlusLayer(
filters=params['convolution_layers']['filters'][i],
kernel_size=params['convolution_layers']['kernel_sizes'][i],
output_dim=motif_dim,
dilation_rate=dilation_rate,
kernel_regularizer=kernel_regularizer,
name=f'EnhancedHyenaPlus_{i+1}'
)
        x = hyena_layer(x)
if params['encode'] == 'one-hot':
x = kl.MaxPooling1D(2)(x)
if params['dropout_conv'] == 'yes':
x = kl.Dropout(params['dropout_prob'])(x)
# Hybrid Context-Aware MSTA processing
ca_msta = HybridContextAwareMSTA(
num_motifs=num_motifs,
motif_dim=motif_dim,
num_heads=params.get('ca_msta_heads', 8),
dropout_rate=params['dropout_prob'],
kernel_regularizer=kernel_regularizer,
activity_regularizer=activity_regularizer
)
x = ca_msta(x)
# Flatten and dense layers
x = kl.Flatten()(x)
# Fully connected layers
    for i in range(params['n_dense_layer']):
        x = kl.Dense(params['dense_neurons' + str(i + 1)],
                     name=f'Dense_{i + 1}')(x)
x = kl.BatchNormalization()(x)
x = kl.Activation('relu')(x)
x = kl.Dropout(params['dropout_prob'])(x)
# Main model bottleneck
bottleneck = x
# Heads per task (developmental and housekeeping enhancer activities)
tasks = ['Dev', 'Hk']
outputs = []
for task in tasks:
        outputs.append(kl.Dense(1, activation='linear', name=f'Dense_{task}')(bottleneck))
# Build Keras model
model = keras.models.Model([input_layer], outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=params['lr']),
        loss=['mse', 'mse'],
        loss_weights=[1, 1]
    )
return model, params
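
# Minimal smoke-test sketch. The dict below only illustrates the params keys
# this file reads ('l2_reg', 'num_motifs', 'motif_dim', and 'ca_msta_heads'
# fall back to their defaults); the values are assumptions, not tuned settings.
if __name__ == '__main__':
    demo_params = {
        'encode': 'one-hot',
        'convolution_layers': {
            'n_layers': 2,
            'filters': [128, 128],
            'kernel_sizes': [7, 7],
        },
        'dropout_conv': 'yes',
        'dropout_prob': 0.2,
        'n_dense_layer': 1,
        'dense_neurons1': 256,
        'lr': 1e-3,
    }
    model, _ = HyenaMSTAPlus(demo_params)
    model.summary()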