import tensorflow as tf
import keras
import keras.layers as kl
from keras_nlp.layers import SinePositionEncoding, TransformerEncoder


class EnhancedHyenaPlusLayer(kl.Layer):
    """
    Enhanced Hyena+DNA layer with multi-scale feature extraction, residual connections,
    explicit dimension alignment, and layer normalization for improved gradient flow
    and stability.
    """

    def __init__(self, filters, kernel_size, output_dim, use_residual=True, dilation_rate=1,
                 kernel_regularizer=None, **kwargs):
        super(EnhancedHyenaPlusLayer, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.output_dim = output_dim
        self.use_residual = use_residual
        self.dilation_rate = dilation_rate
        self.kernel_regularizer = kernel_regularizer

        # Main convolution branch.
        self.conv = kl.Conv1D(filters, kernel_size, padding='same',
                              kernel_regularizer=kernel_regularizer)

        # Dilated branch: widens the receptive field at the same kernel size.
        self.dilated_conv = kl.Conv1D(filters // 2, kernel_size,
                                      padding='same',
                                      dilation_rate=dilation_rate,
                                      kernel_regularizer=kernel_regularizer)

        # Local branch: fixed kernel size of 3 to capture short-range patterns.
        self.local_conv = kl.Conv1D(filters // 2, 3, padding='same',
                                    kernel_regularizer=kernel_regularizer)

        self.batch_norm = kl.BatchNormalization()
        self.activation = kl.Activation('relu')

        # Fuse the dilated and local branches back to `filters` channels.
        self.fusion = kl.Dense(filters, kernel_regularizer=kernel_regularizer)

        # Project to the requested output dimension.
        self.projection = kl.Dense(output_dim, kernel_regularizer=kernel_regularizer)

        self.layer_norm = kl.LayerNormalization()

        # Residual path: project the input so its last dimension matches `output_dim`.
        self.input_projection = None
        if use_residual:
            self.input_projection = kl.Dense(output_dim, kernel_regularizer=kernel_regularizer)

    def call(self, inputs, training=None):
        residual = inputs

        # Multi-scale feature extraction: main, dilated, and local convolutions.
        x_main = self.conv(inputs)
        x_dilated = self.dilated_conv(inputs)
        x_local = self.local_conv(inputs)

        # Concatenate the two half-width branches and fuse them with the main branch.
        x_multi = tf.concat([x_dilated, x_local], axis=-1)
        x = self.fusion(x_multi) + x_main

        x = self.batch_norm(x, training=training)
        x = self.activation(x)

        # Align the channel dimension with `output_dim`.
        x = self.projection(x)

        # Residual connection with explicit dimension alignment.
        if self.use_residual:
            residual = self.input_projection(residual)
            x = x + residual

        x = self.layer_norm(x)
        return x

    def get_config(self):
        config = super(EnhancedHyenaPlusLayer, self).get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'output_dim': self.output_dim,
            'use_residual': self.use_residual,
            'dilation_rate': self.dilation_rate,
            # Serialize the regularizer so the config stays JSON-compatible.
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
                                  if self.kernel_regularizer else None
        })
        return config
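

# A minimal, self-contained sketch (not part of the model code) showing how
# EnhancedHyenaPlusLayer could be exercised on a dummy one-hot batch. The batch
# size, filters, kernel size, and dilation below are illustrative assumptions;
# the (249, 4) sequence shape matches the one-hot input used by HyenaMSTAPlus.
def _demo_enhanced_hyena_layer():
    dummy = tf.random.uniform((2, 249, 4))  # (batch, length, channels)
    layer = EnhancedHyenaPlusLayer(filters=64, kernel_size=9,
                                   output_dim=96, dilation_rate=2)
    out = layer(dummy, training=False)
    print(out.shape)  # expected: (2, 249, 96)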


class HybridContextAwareMSTA(kl.Layer):
    """
    Hybrid Context-Aware Motif-Specific Transformer Attention (HCA-MSTA) module
    with enhanced biological interpretability and selective motif attention.
    Combines the strengths of previous approaches with improved positional encoding.
    """

    def __init__(self, num_motifs, motif_dim, num_heads=4, dropout_rate=0.1,
                 kernel_regularizer=None, activity_regularizer=None, **kwargs):
        super(HybridContextAwareMSTA, self).__init__(**kwargs)
        self.num_motifs = num_motifs
        self.motif_dim = motif_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.kernel_regularizer = kernel_regularizer
        self.activity_regularizer = activity_regularizer

        # Learnable motif prototypes attended to by every sequence position.
        self.motif_embeddings = self.add_weight(
            shape=(num_motifs, motif_dim),
            initializer='glorot_uniform',
            regularizer=activity_regularizer,
            trainable=True,
            name='motif_embeddings'
        )

        # Learnable positional encoding added to the motif prototypes.
        self.motif_position_encoding = self.add_weight(
            shape=(num_motifs, motif_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='motif_position_encoding'
        )

        # Per-motif importance weights used by the positional masking.
        self.motif_importance = self.add_weight(
            shape=(num_motifs, 1),
            initializer='ones',
            regularizer=activity_regularizer,
            trainable=True,
            name='motif_importance'
        )

        # Queries come from the sequence; keys and values from the motif prototypes.
        self.query_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
        self.key_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
        self.value_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)

        self.attention = kl.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=motif_dim // num_heads,
            dropout=dropout_rate
        )

        # Sigmoid gate that modulates the attention output per position.
        self.gate_dense = kl.Dense(motif_dim, activation='sigmoid',
                                   kernel_regularizer=kernel_regularizer)

        self.output_dense = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
        self.dropout = kl.Dropout(dropout_rate)
        self.layer_norm = kl.LayerNormalization()

        # Position-wise feed-forward block applied after the attention residual.
        self.ffn_dense1 = kl.Dense(motif_dim * 2, activation='relu',
                                   kernel_regularizer=kernel_regularizer)
        self.ffn_dense2 = kl.Dense(motif_dim, kernel_regularizer=kernel_regularizer)
        self.ffn_layer_norm = kl.LayerNormalization()
        self.ffn_dropout = kl.Dropout(dropout_rate)

    def positional_masking(self, sequence_embeddings, motif_embeddings):
        """
        Generate hybrid positional masking based on sequence and motif relevance
        with improved biological context awareness and motif importance weighting.
        Combines inverse-distance and Gaussian approaches for better biological relevance.
        """
        # Scaled dot-product similarity between sequence positions and motifs.
        similarity = tf.matmul(sequence_embeddings, tf.transpose(motif_embeddings, [0, 2, 1]))
        scaled_similarity = similarity / tf.sqrt(tf.cast(self.motif_dim, tf.float32))
        attention_weights = tf.nn.softmax(scaled_similarity, axis=-1)

        seq_length = tf.shape(sequence_embeddings)[1]
        motif_length = tf.shape(motif_embeddings)[1]

        # Signed offset between each sequence position and each motif index.
        position_indices = tf.range(seq_length)[:, tf.newaxis] - tf.range(motif_length)[tf.newaxis, :]
        position_indices_float = tf.cast(position_indices, tf.float32)

        # Inverse-distance weighting: 1 / (1 + |d|).
        inverse_weights = 1.0 / (1.0 + tf.abs(position_indices_float))

        # Gaussian weighting with a standard deviation of 8 positions.
        gaussian_weights = tf.exp(-0.5 * tf.square(position_indices_float / 8.0))

        # Hybrid scheme: equal blend of the two distance profiles.
        position_weights = 0.5 * inverse_weights + 0.5 * gaussian_weights
        position_weights = tf.expand_dims(position_weights, 0)

        # Learned per-motif importance, scaled before the softmax.
        motif_weights = tf.nn.softmax(self.motif_importance * 1.5, axis=0)
        motif_weights = tf.expand_dims(tf.expand_dims(motif_weights, 0), 1)

        combined_weights = attention_weights * position_weights * tf.squeeze(motif_weights, -1)
        return combined_weights

    def call(self, inputs, training=None):
        batch_size = tf.shape(inputs)[0]

        # Broadcast the motif prototypes and their positional encoding over the batch.
        motifs = tf.tile(tf.expand_dims(self.motif_embeddings, 0), [batch_size, 1, 1])
        pos_encoding = tf.tile(tf.expand_dims(self.motif_position_encoding, 0), [batch_size, 1, 1])
        motifs_with_pos = motifs + pos_encoding

        # Sequence positions query the motif prototypes.
        query = self.query_dense(inputs)
        key = self.key_dense(motifs_with_pos)
        value = self.value_dense(motifs_with_pos)

        # Hybrid positional mask combining similarity, distance, and motif importance.
        pos_mask = self.positional_masking(query, motifs_with_pos)

        attention_output = self.attention(
            query=query,
            key=key,
            value=value,
            attention_mask=pos_mask,
            training=training
        )

        # Gate the attended motif information before mixing it back into the sequence.
        gate = self.gate_dense(inputs)
        gated_attention = gate * attention_output

        output = self.output_dense(gated_attention)
        output = self.dropout(output, training=training)
        output = self.layer_norm(output + inputs)

        # Position-wise feed-forward block with its own residual connection.
        ffn_output = self.ffn_dense1(output)
        ffn_output = self.ffn_dense2(ffn_output)
        ffn_output = self.ffn_dropout(ffn_output, training=training)
        final_output = self.ffn_layer_norm(output + ffn_output)

        return final_output

    def get_config(self):
        config = super(HybridContextAwareMSTA, self).get_config()
        config.update({
            'num_motifs': self.num_motifs,
            'motif_dim': self.motif_dim,
            'num_heads': self.num_heads,
            'dropout_rate': self.dropout_rate,
            # Serialize the regularizers so the config stays JSON-compatible.
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
                                  if self.kernel_regularizer else None,
            'activity_regularizer': keras.regularizers.serialize(self.activity_regularizer)
                                    if self.activity_regularizer else None
        })
        return config
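

# A minimal, self-contained sketch (not part of the model code) exercising
# HybridContextAwareMSTA on a dummy feature map. The batch size, sequence
# length, and motif settings below are illustrative assumptions, chosen to
# mirror the (seq_len, motif_dim) tensors produced by the Hyena stack.
def _demo_hybrid_msta():
    dummy = tf.random.uniform((2, 124, 96))  # (batch, seq_len, motif_dim)
    msta = HybridContextAwareMSTA(num_motifs=48, motif_dim=96, num_heads=4)
    out = msta(dummy, training=False)
    print(out.shape)  # expected: (2, 124, 96)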


def HyenaMSTAPlus(params):
    """
    Enhanced HyenaMSTA+ model for enhancer activity prediction with multi-scale feature
    extraction, a hybrid attention mechanism, and improved biological context modeling.
    """
    # Input encoding: one-hot 249-bp sequences or a single 64-dimensional k-mer vector.
    if params['encode'] == 'one-hot':
        input_layer = kl.Input(shape=(249, 4))
    elif params['encode'] == 'k-mer':
        input_layer = kl.Input(shape=(1, 64))
    else:
        raise ValueError(f"Unsupported encoding: {params['encode']}")

    # Regularization: L2 on kernels, a lighter L1 penalty on the motif weights.
    l2_reg = params.get('l2_reg', 1e-6)
    kernel_regularizer = tf.keras.regularizers.l2(l2_reg)
    activity_regularizer = tf.keras.regularizers.l1(l2_reg / 20)

    x = input_layer
    hyena_layers = []

    num_motifs = params.get('num_motifs', 48)
    motif_dim = params.get('motif_dim', 96)

    # Stack of enhanced Hyena layers with increasing dilation (capped at 4).
    for i in range(params['convolution_layers']['n_layers']):
        dilation_rate = 2 ** min(i, 2)

        hyena_layer = EnhancedHyenaPlusLayer(
            filters=params['convolution_layers']['filters'][i],
            kernel_size=params['convolution_layers']['kernel_sizes'][i],
            output_dim=motif_dim,
            dilation_rate=dilation_rate,
            kernel_regularizer=kernel_regularizer,
            name=f'EnhancedHyenaPlus_{i+1}'
        )
        x = hyena_layer(x)
        hyena_layers.append(x)

        # Downsample only for one-hot inputs; k-mer inputs have length 1.
        if params['encode'] == 'one-hot':
            x = kl.MaxPooling1D(2)(x)

        if params['dropout_conv'] == 'yes':
            x = kl.Dropout(params['dropout_prob'])(x)

    # Hybrid context-aware motif attention over the convolutional feature map.
    ca_msta = HybridContextAwareMSTA(
        num_motifs=num_motifs,
        motif_dim=motif_dim,
        num_heads=params.get('ca_msta_heads', 8),
        dropout_rate=params['dropout_prob'],
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer
    )
    x = ca_msta(x)

    x = kl.Flatten()(x)

    # Fully connected head shared by both output tasks.
    for i in range(params['n_dense_layer']):
        x = kl.Dense(params['dense_neurons' + str(i + 1)],
                     name='Dense_' + str(i + 1))(x)
        x = kl.BatchNormalization()(x)
        x = kl.Activation('relu')(x)
        x = kl.Dropout(params['dropout_prob'])(x)

    bottleneck = x

    # Two linear regression heads, one per task.
    tasks = ['Dev', 'Hk']
    outputs = []
    for task in tasks:
        outputs.append(kl.Dense(1, activation='linear', name='Dense_' + task)(bottleneck))

    model = keras.models.Model([input_layer], outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=params['lr']),
        loss=['mse', 'mse'],
        loss_weights=[1, 1]
    )

    return model, params
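

# A hedged end-to-end sketch (illustrative only): the parameter values below are
# assumptions chosen to satisfy the keys HyenaMSTAPlus reads from `params`; they
# are not a recommended or published training configuration.
if __name__ == '__main__':
    demo_params = {
        'encode': 'one-hot',
        'convolution_layers': {'n_layers': 2,
                               'filters': [128, 128],
                               'kernel_sizes': [9, 5]},
        'dropout_conv': 'yes',
        'dropout_prob': 0.2,
        'n_dense_layer': 1,
        'dense_neurons1': 256,
        'lr': 1e-3,
    }
    model, _ = HyenaMSTAPlus(demo_params)
    model.summary()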
|