Gated Residual Networks (GRNs) in Temporal Fusion Transformers

Gated Residual Networks (GRNs) are a key component of the Temporal Fusion Transformer (TFT) architecture. They let the model learn complex, non-linear relationships in the data while keeping a direct path for information and gradients, so a GRN can apply as much or as little non-linear processing as a given input needs.

Key Features of Gated Residual Networks

  1. Non-Linear Transformations: GRNs apply non-linear transformations to the input data, allowing the network to learn complex relationships.

  2. Gating Mechanism: A sigmoid gate scales each component of the transformed signal by a value between 0 and 1, so the network can pass useful information and suppress the rest. If the gate closes completely, the block can skip its non-linear processing altogether.

  3. Residual Connections: Adding the input back to the output gives gradients a direct path through the block, which mitigates the vanishing gradient problem and makes deep networks easier to train.
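The interplay of the gate (feature 2) and the residual path (feature 3) can be illustrated in a few lines of Python. This is a minimal sketch, not a full GRN: the branch output `h` and the gate pre-activations are hypothetical numbers chosen for illustration, and `gated_residual` is a helper named here only for this example.

```python
import math


def sigmoid(v: float) -> float:
    """Logistic sigmoid, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-v))


def gated_residual(x: list[float], h: list[float], gate_logits: list[float]) -> list[float]:
    """Return x + sigmoid(a) * h element-wise.

    x: the residual (skip) input
    h: output of a non-linear branch (hypothetical values here)
    gate_logits: pre-sigmoid gate activations, one per component
    """
    return [xi + sigmoid(ai) * hi for xi, hi, ai in zip(x, h, gate_logits)]


x = [1.0, -0.5, 2.0]
h = [0.3, 0.8, -1.2]

closed = gated_residual(x, h, [-10.0, -10.0, -10.0])  # gate values near 0
opened = gated_residual(x, h, [10.0, 10.0, 10.0])     # gate values near 1
print([round(v, 3) for v in closed])  # [1.0, -0.5, 2.0]
print([round(v, 3) for v in opened])  # [1.3, 0.3, 0.8]
```

With the gate closed, the output is essentially the input; with it open, the branch output is added almost in full.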

Detailed Explanation

A Gated Residual Network combines several components to effectively process the input data:

  1. Input Transformation: The input data is first passed through a fully connected (dense) layer to transform it into a suitable representation. This can be represented as: $$ x' = W_x x + b_x $$ where $W_x$ and $b_x$ are the weights and bias of the dense layer, and $x$ is the input.

  2. Gating Mechanism: The gate controls how much of the transformed signal reaches the output. A second dense layer followed by a sigmoid activation produces one value between 0 and 1 for each component of $x'$: values near 0 block that component and values near 1 let it through. The gate can be represented as: $$ g = \sigma(W_g x' + b_g) $$

    where $W_g$ and $b_g$ are the weights and bias of the gating layer, and $\sigma$ is the sigmoid function.

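  3. Element-wise Multiplication: A separate dense layer with a non-linear activation produces a candidate output, which is multiplied element-wise by the gate: $$ z = g \odot \phi(W_z x' + b_z) $$ where $W_z$ and $b_z$ are the weights and bias, $\phi$ is a non-linear activation function (e.g., ReLU; the original TFT uses ELU), and $\odot$ denotes element-wise multiplication. Because every entry of $g$ lies between 0 and 1, the gate can attenuate the candidate but never amplify it.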

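  4. Residual Connection: The original input is added to the gated output to form the output of the GRN: $$ y = z + x $$ This requires $z$ and $x$ to have the same dimension; when they differ, implementations typically pass $x$ through a linear projection first. The skip path gives gradients a direct route through the block, which makes deeper networks easier to train, and when the gate is close to zero the block passes its input through almost unchanged.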

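  5. Layer Normalization: In the original TFT formulation, layer normalization is part of the GRN rather than optional. It is applied to the residual sum, $$ y = \mathrm{LayerNorm}(x + z) $$ which keeps the scale of the outputs stable and helps training.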
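The five steps above can be sketched in plain Python, with lists standing in for vectors and matrices. This is a minimal sketch under stated assumptions: it implements the simplified formulas given here with ReLU as $\phi$, not the exact layer from the TFT paper; the layer normalization omits the learned gain and bias; and the parameter values in `params` are hypothetical, chosen so the result is easy to check by hand.

```python
import math

Vector = list[float]
Matrix = list[list[float]]


def dense(W: Matrix, b: Vector, v: Vector) -> Vector:
    """Fully connected layer: returns W v + b."""
    return [sum(w * vj for w, vj in zip(row, v)) + bi for row, bi in zip(W, b)]


def sigmoid(v: float) -> float:
    """Logistic sigmoid, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-v))


def relu(v: float) -> float:
    """ReLU, used here as the non-linearity phi."""
    return max(0.0, v)


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learned gain or bias)."""
    mean = sum(v) / len(v)
    var = sum((vi - mean) ** 2 for vi in v) / len(v)
    return [(vi - mean) / math.sqrt(var + eps) for vi in v]


def grn(x: Vector, p: dict, use_layer_norm: bool = True) -> Vector:
    """Simplified GRN following steps 1-5 above (hypothetical parameters p)."""
    x_prime = dense(p["W_x"], p["b_x"], x)                         # step 1
    g = [sigmoid(a) for a in dense(p["W_g"], p["b_g"], x_prime)]  # step 2
    candidate = [relu(a) for a in dense(p["W_z"], p["b_z"], x_prime)]
    z = [gi * ci for gi, ci in zip(g, candidate)]                  # step 3
    if len(z) != len(x):
        raise ValueError("residual connection needs len(z) == len(x)")
    y = [zi + xi for zi, xi in zip(z, x)]                          # step 4
    return layer_norm(y) if use_layer_norm else y                  # step 5


# Hypothetical parameters: identity input and candidate layers, all-zero gate layer.
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
zeros = [[0.0] * 3 for _ in range(3)]
params = {
    "W_x": identity, "b_x": [0.0] * 3,
    "W_g": zeros, "b_g": [0.0] * 3,
    "W_z": identity, "b_z": [0.0] * 3,
}

x = [1.0, -2.0, 0.5]
raw = grn(x, params, use_layer_norm=False)
normed = grn(x, params)
print(raw)  # [1.5, -2.0, 0.75]
print(normed)
```

With these parameters $x' = x$, every gate value is $\sigma(0) = 0.5$, and the candidate is $\mathrm{ReLU}(x)$, so the output before normalization is $y = x + 0.5\,\mathrm{ReLU}(x)$.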

Example

Consider a daily sales time series with two additional covariates: a promotion indicator and a holiday indicator. A GRN processes this input as follows:

  1. Input Transformation: The input features (sales, promotions, holidays) are first transformed through a dense layer to create an intermediate representation.

  2. Gating Mechanism: The gate assigns each component of the intermediate representation a value between 0 and 1. For instance, on a promotion day the gate might stay open for the components that carry the promotion signal.

  3. Element-wise Multiplication: The non-linear transform of the intermediate representation is multiplied element-wise by the gate output, attenuating less useful components and passing useful ones through largely unchanged.

  4. Residual Connection: The original input features are added back to the gated output to preserve the original information while adding the learned transformations.
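The idea that the gate responds to the input, for instance opening on promotion days, can be illustrated with hypothetical gate parameters. This sketch assumes, for simplicity, that the intermediate representation equals the input vector `[scaled_sales, promotion, holiday]`; the weights `W_g`, `b_g`, the helper `gate`, and the scaled sales value are invented for illustration, not learned values.

```python
import math


def sigmoid(v: float) -> float:
    """Logistic sigmoid, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-v))


def gate(x_prime: list[float], W_g: list[list[float]], b_g: list[float]) -> list[float]:
    """Compute g = sigmoid(W_g x' + b_g), one gate value per component."""
    return [
        sigmoid(sum(w * v for w, v in zip(row, x_prime)) + b)
        for row, b in zip(W_g, b_g)
    ]


# Hypothetical gate parameters: only the second gate unit reads the
# promotion indicator (weight 4.0, bias -2.0); the others stay at sigmoid(0).
W_g = [[0.0, 0.0, 0.0],
       [0.0, 4.0, 0.0],
       [0.0, 0.0, 0.0]]
b_g = [0.0, -2.0, 0.0]

# Assumed intermediate representations: [scaled_sales, promotion, holiday].
promo_day = [1.2, 1.0, 0.0]
regular_day = [1.2, 0.0, 0.0]

g_promo = gate(promo_day, W_g, b_g)
g_regular = gate(regular_day, W_g, b_g)
print([round(v, 3) for v in g_promo])    # [0.5, 0.881, 0.5]
print([round(v, 3) for v in g_regular])  # [0.5, 0.119, 0.5]
```

The same parameters open the second gate on the promotion day and mostly close it on the regular day; in a trained model, this behavior would have to be learned from data.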

Visualisation

Imagine the following input features for a day:

  • Sales Data: 150
  • Promotion Indicator: 1
  • Holiday Indicator: 0

Step-by-Step Processing:

  1. Input Transformation: The input is transformed to an intermediate representation, say $x' = [0.8, 0.5, 0.2]$.

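  2. Gating Mechanism: The gate might produce $g = [0.9, 0.7, 0.1]$. If, for illustration, each component is associated with one input feature, the gate passes most of the sales and promotion components and suppresses the holiday component, which fits a promotion day that is not a holiday.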

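  3. Element-wise Multiplication: Assuming, for simplicity, that the candidate branch $\phi(W_z x' + b_z)$ returns $x'$ unchanged, the gated output is $z = g \odot x' = [0.9 \cdot 0.8, 0.7 \cdot 0.5, 0.1 \cdot 0.2] = [0.72, 0.35, 0.02]$.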

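  4. Residual Connection: Adding back the original input $x = [150, 1, 0]$ gives $y = z + x = [150.72, 1.35, 0.02]$, which combines the original and transformed information. The raw sales value swamps the other components, which is why inputs are usually scaled (for example, standardized) before they reach the network.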
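The arithmetic in the step-by-step example can be reproduced in a few lines. The sketch assumes, as a simplification, that the candidate branch returns $x'$ unchanged, so $z = g \odot x'$; the input, $x'$, and $g$ are the illustrative values from the example, not outputs of a trained model.

```python
# Illustrative values from the example above (not learned values).
x = [150.0, 1.0, 0.0]        # raw input: sales, promotion, holiday
x_prime = [0.8, 0.5, 0.2]    # intermediate representation
g = [0.9, 0.7, 0.1]          # gate output

# Assumption: the candidate branch phi(W_z x' + b_z) returns x' unchanged.
z = [gi * xi for gi, xi in zip(g, x_prime)]

# Residual connection: add the original input back.
y = [zi + xi for zi, xi in zip(z, x)]

print([round(v, 2) for v in z])  # [0.72, 0.35, 0.02]
print([round(v, 2) for v in y])  # [150.72, 1.35, 0.02]
```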

Summary

Gated Residual Networks in Temporal Fusion Transformers combine non-linear transformations, gating, and residual connections. The gate lets each GRN decide how much non-linear processing to apply, down to skipping it almost entirely, while the residual path keeps information and gradients flowing. This helps the TFT capture complex patterns and dependencies in time-series data without forcing extra processing on inputs that do not need it.