Variable Selection Network (VSN) in TFT

The Variable Selection Network (VSN) is a critical component in Temporal Fusion Transformers (TFT), designed to identify and emphasize the most important input variables at each step in a time series. It dynamically selects relevant features, improving both interpretability and model performance by reducing the influence of irrelevant or noisy inputs.

Key Features of Variable Selection Networks

  1. Dynamic Variable Selection:
    VSN dynamically adjusts the importance of each variable at every time step, allowing the model to focus on the most relevant features as patterns change over time.

  2. Gating Mechanism:
    A gating mechanism assigns weights to input variables, controlling their contribution to the model’s predictions. Variables deemed more critical receive higher weights, enhancing their influence.

  3. Interpretability:
    By clearly indicating variable importance over time, VSN provides valuable insights into which features drive the model’s forecasts.

Detailed Explanation

A Variable Selection Network integrates several steps to select and weigh input features effectively:

1. Feature Embedding:

The input features at each time step are embedded into a higher-dimensional space using individual dense layers for each variable. This helps the network better capture the nuances of each variable.

Mathematically, for each input variable $x_j$:

$$ e_j = W_j x_j + b_j $$

where $W_j$ and $b_j$ are learnable parameters specific to variable $j$.
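
A minimal PyTorch sketch of this embedding step follows; the class name `VariableEmbedding` and the dimension `d_model` are illustrative choices, not names from the TFT reference implementation:

```python
import torch
import torch.nn as nn

class VariableEmbedding(nn.Module):
    """Embed each scalar input variable into d_model dimensions
    using its own linear layer (W_j, b_j)."""

    def __init__(self, num_vars: int, d_model: int):
        super().__init__()
        # One dense layer per variable: e_j = W_j x_j + b_j.
        self.embeddings = nn.ModuleList(
            [nn.Linear(1, d_model) for _ in range(num_vars)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_vars) -> e: (batch, num_vars, d_model)
        return torch.stack(
            [emb(x[:, j : j + 1]) for j, emb in enumerate(self.embeddings)],
            dim=1,
        )
```

In practice, categorical inputs (such as the day-of-week feature in the example below) would use lookup embeddings rather than linear layers; the sketch treats all inputs as continuous for simplicity.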

2. Variable Importance Computation:

A gating network computes the relative importance of each embedded variable. Each embedding is first projected to a scalar score, and a softmax across all variables turns the scores into weights that sum to one:

$$ \alpha_j = \frac{\exp(w_g^\top e_j + b_g)}{\sum_k \exp(w_g^\top e_k + b_g)} $$

where $\alpha_j$ is the importance weight for the $j^{th}$ variable, and $w_g$, $b_g$ are gating-network parameters. (In the full TFT, these weights are produced by a Gated Residual Network applied to the concatenation of all embeddings; the linear scoring shown here is a simplified version of the same idea.)
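
One way to realise this gating in PyTorch, continuing the sketch above (again a simplification, since the full TFT derives these weights from a Gated Residual Network rather than a single shared linear scorer):

```python
class VariableGating(nn.Module):
    """Score each embedded variable with a shared linear layer (w_g, b_g)
    and normalise across variables so the weights alpha_j sum to one."""

    def __init__(self, d_model: int):
        super().__init__()
        self.score = nn.Linear(d_model, 1)  # shared scalar scorer

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        # e: (batch, num_vars, d_model) -> alpha: (batch, num_vars)
        logits = self.score(e).squeeze(-1)
        return torch.softmax(logits, dim=-1)  # softmax across variables
```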

3. Aggregation of Variables:

Each embedded feature $e_j$ is scaled by its importance weight $\alpha_j$, and the results are summed to form a combined representation for the time step:

$$ v = \sum_j \alpha_j e_j $$

This aggregated vector $v$ encapsulates the most critical information from all input variables at that time step.
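
Putting the three steps together, a minimal VSN module could look as follows (reusing the `VariableEmbedding` and `VariableGating` sketches above; an illustrative simplification, not the reference implementation):

```python
class VariableSelectionNetwork(nn.Module):
    """Minimal VSN: embed each variable, weight it, and aggregate."""

    def __init__(self, num_vars: int, d_model: int):
        super().__init__()
        self.embed = VariableEmbedding(num_vars, d_model)
        self.gate = VariableGating(d_model)

    def forward(self, x: torch.Tensor):
        e = self.embed(x)                         # (batch, num_vars, d_model)
        alpha = self.gate(e)                      # (batch, num_vars)
        v = (alpha.unsqueeze(-1) * e).sum(dim=1)  # v = sum_j alpha_j * e_j
        return v, alpha

# Example: a batch of 32 time steps, 4 input variables each.
vsn = VariableSelectionNetwork(num_vars=4, d_model=8)
v, alpha = vsn(torch.randn(32, 4))
print(v.shape, alpha.shape)  # torch.Size([32, 8]) torch.Size([32, 4])
```

Returning `alpha` alongside `v` is what makes the selection interpretable: the weights can be logged per time step and inspected after training.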

Example

Consider forecasting daily electricity demand with inputs like historical demand, temperature, holiday indicators, and day-of-week features:

Step-by-Step Illustration:

  1. Input Features:

    • Electricity Demand: 3500 MW
    • Temperature: 28 °C
    • Holiday Indicator: 0 (no holiday)
    • Day-of-Week: Monday
  2. Feature Embedding:
    Each input is embedded separately:

    • Demand embedding: $e_{\text{demand}} = [0.9, 0.2]$
    • Temperature embedding: $e_{\text{temp}} = [0.7, 0.4]$
    • Holiday embedding: $e_{\text{holiday}} = [0.1, 0.3]$
    • Day-of-week embedding: $e_{\text{weekday}} = [0.6, 0.5]$
  3. Importance Weights:
    The gating network calculates weights, e.g.:

    • Demand: 0.5
    • Temperature: 0.3
    • Holiday: 0.05
    • Day-of-week: 0.15
  4. Aggregation:
    The embeddings are scaled and summed:

$$ v = 0.5 \cdot e_{\text{demand}} + 0.3 \cdot e_{\text{temp}} + 0.05 \cdot e_{\text{holiday}} + 0.15 \cdot e_{\text{weekday}} $$

Substituting the embeddings above gives $v = [0.755,\, 0.31]$. This final vector $v$ carries the weighted information the model needs to predict electricity demand at that time step.
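
The arithmetic is easy to verify with a few lines of NumPy:

```python
import numpy as np

alpha = np.array([0.5, 0.3, 0.05, 0.15])   # importance weights
E = np.array([[0.9, 0.2],                  # e_demand
              [0.7, 0.4],                  # e_temp
              [0.1, 0.3],                  # e_holiday
              [0.6, 0.5]])                 # e_weekday

v = alpha @ E                              # v = sum_j alpha_j * e_j
print(v)                                   # [0.755 0.31 ]
```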

Visualization Example

Each variable’s importance weight can be plotted across time to show how the model shifts its focus (see the sketch after this list). For instance:

  • Summer Months: High weight on Temperature.
  • Weekdays: Significant weight on Demand history and Day-of-week.
  • Holidays: Increased weight on Holiday indicator.
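
A minimal plotting sketch, using randomly generated weights as a stand-in for the per-day $\alpha_j$ values a trained model would produce:

```python
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data: 60 days of importance weights that sum to one per day.
days = np.arange(60)
weights = np.random.dirichlet([5, 3, 1, 2], size=len(days))  # (60, 4)

labels = ["Demand", "Temperature", "Holiday", "Day-of-week"]
plt.stackplot(days, weights.T, labels=labels)
plt.xlabel("Day")
plt.ylabel("Importance weight")
plt.legend(loc="upper right")
plt.title("VSN variable importance over time")
plt.show()
```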

Summary

The Variable Selection Network in Temporal Fusion Transformers dynamically identifies and prioritizes input variables at every time step. By embedding inputs and applying gating mechanisms, VSN filters out irrelevant noise and enhances interpretability. This dynamic selection can significantly improve forecasting accuracy, particularly in complex time-series tasks with varying patterns and dependencies.