Introduction to TFT
Introduction
This post kicks off a multi-part series dedicated to breaking down the Temporal Fusion Transformer (TFT) model in an easy-to-understand manner. Throughout the series, we will explore each major component of the TFT architecture in detail, aiming to provide both clarity and practical insights for readers interested in time-series forecasting and deep learning models.
Temporal Fusion Transformer: What is it and where is it used?
The Temporal Fusion Transformer (TFT) is a powerful machine learning model specifically designed for making predictions about future events based on time-series data. Think of it as a sophisticated crystal ball that can look at patterns from the past and present to make educated guesses about what might happen in the future.
TFT is particularly useful in scenarios like:
- Predicting electricity demand for power grids
- Forecasting sales for retail businesses
- Estimating future stock prices
- Predicting weather patterns
- Planning inventory levels for warehouses
What makes TFT special is its ability to handle different types of data (like historical numbers, categorical information, and known future events) while providing not just single predictions, but a range of possible outcomes with their probabilities.
Figure 1: Temporal Fusion Transformer architecture. Source: Lim et al. (2021)
1. Overview
TFT is designed to ingest:
- Static covariates (time-invariant features) - These are characteristics that don’t change over time, like a store’s location or a product’s category
- Past observed inputs (known only up to the current time $t$) - Historical data we’ve already seen, like past sales or weather conditions
- Known future inputs (features whose future values are known, e.g. calendar variables) - Information we already know about the future, like holidays or scheduled events
It outputs quantile forecasts for multiple horizons $h = 1,\dots,H$ - This means it gives us a range of possible outcomes with their probabilities, not just a single prediction.
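As a concrete illustration of these three input groups and the quantile outputs, here is a minimal sketch of the tensor shapes a TFT-style model consumes and produces; the feature counts and sizes below are made up for illustration.

```python
import torch

batch, k, H = 8, 48, 24                 # hypothetical: batch size, history length k, horizon H
static = torch.randn(batch, 3)          # static covariates, e.g. store location / product category
past = torch.randn(batch, k, 5)         # past observed inputs, known only up to time t
future = torch.randn(batch, H, 2)       # known future inputs, e.g. calendar variables
quantiles = [0.1, 0.5, 0.9]

# A TFT maps (static, past, future) to quantile forecasts of shape (batch, H, len(quantiles)).
```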
2. Static Covariate Encoding
In simple terms, this step takes information that doesn’t change over time (like a store’s size or a product’s category) and converts it into a format that the model can use effectively. Think of it as creating a digital fingerprint for each unchanging characteristic.
- We embed each static feature vector $s$ into a context vector $S$ via a small Gated Residual Network (GRN):
$$ S = \mathrm{GRN}_{\mathrm{static}}(s) $$
- This $S$ will be used to “enrich” all future time-step representations.
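A minimal sketch of this encoding step, assuming one categorical and one continuous static feature (both hypothetical) and using a plain MLP as a stand-in for $\mathrm{GRN}_{\mathrm{static}}$ (a fuller GRN sketch appears in Section 6.2):

```python
import torch
import torch.nn as nn

batch, d_model = 8, 32
store_id = torch.randint(0, 100, (batch,))        # hypothetical categorical static feature
store_size = torch.randn(batch, 1)                # hypothetical continuous static feature

embed = nn.Embedding(100, d_model)                # embed the categorical feature
proj = nn.Linear(1, d_model)                      # project the continuous feature
grn_static = nn.Sequential(                       # plain MLP standing in for GRN_static
    nn.Linear(2 * d_model, d_model), nn.ELU(), nn.Linear(d_model, d_model)
)

S = grn_static(torch.cat([embed(store_id), proj(store_size)], dim=-1))   # (batch, d_model)
```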
3. Variable Selection Network
This is like having a smart assistant that knows which information is most important at any given time. For example, when predicting ice cream sales, temperature might be very important in summer but less so in winter.
At each time step $t$, we have input features $x_t^{(1)},\dots,x_t^{(m)}$ (continuous or categorical). TFT learns to weight them dynamically:
- Linear/Embedding transform
For each feature $j$:
$$ \xi_t^{(j)} = W_j\, x_t^{(j)} + b_j $$
- Compute feature weights
Stack the $\{\xi_t^{(j)}\}$ into a vector $\Xi_t$, then pass it through a GRN and a softmax:
$$ v_t = \mathrm{softmax}\bigl(\mathrm{GRN}_{\mathrm{sel}}(\Xi_t, S)\bigr), \qquad v_t \in \mathbb{R}^m,\ \sum_j v_t^{(j)} = 1 $$
- Fuse inputs
Weighted sum of transformed features:
$$ \tilde{\xi}_t = \sum_{j=1}^m v_t^{(j)} \xi_t^{(j)} $$
This produces a single vector $\tilde{\xi}_t$ per time step, while the weights $v_t$ are directly interpretable.
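The sketch below shows one way this could look in PyTorch, assuming all $m$ inputs are continuous scalars and using a plain MLP in place of $\mathrm{GRN}_{\mathrm{sel}}$ (the static context $S$ is omitted for brevity):

```python
import torch
import torch.nn as nn

class VariableSelectionNetwork(nn.Module):
    """Minimal sketch: transform, weight and fuse m features at one time step."""
    def __init__(self, num_features: int, d_model: int):
        super().__init__()
        # one linear transform per raw feature: xi_t^(j) = W_j x_t^(j) + b_j
        self.transforms = nn.ModuleList(nn.Linear(1, d_model) for _ in range(num_features))
        # stand-in for GRN_sel: produces one selection logit per feature
        self.weight_net = nn.Sequential(
            nn.Linear(num_features * d_model, d_model), nn.ELU(), nn.Linear(d_model, num_features)
        )

    def forward(self, x):                               # x: (batch, num_features)
        xi = torch.stack([f(x[:, j:j + 1]) for j, f in enumerate(self.transforms)], dim=1)
        v = torch.softmax(self.weight_net(xi.flatten(1)), dim=-1)     # feature weights v_t
        fused = (v.unsqueeze(-1) * xi).sum(dim=1)                     # weighted sum over features
        return fused, v                                 # fused: (batch, d_model), v: (batch, m)

fused, weights = VariableSelectionNetwork(num_features=5, d_model=32)(torch.randn(8, 5))
```

The returned `weights` tensor is exactly the $v_t$ above, which is what makes this component directly interpretable.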
4. LSTM Encoder-Decoder
Think of this as the model’s memory system. It has two parts:
- An encoder that remembers what happened in the past
- A decoder that uses this memory to make predictions about the future
4.1 Encoder (Past Inputs)
- The sequence of fused past inputs $\{\tilde{\xi}_{t-k},\dots,\tilde{\xi}_t\}$ is fed into an LSTM:
$$ h_t, c_t = \mathrm{LSTM}_{\mathrm{enc}}(\tilde{\xi}_t, h_{t-1}, c_{t-1}) $$
- Its final hidden state $(h_t, c_t)$ summarizes historical information.
4.2 Decoder (Known Future Inputs)
- Starting from $(h_t,c_t)$, the decoder LSTM processes known future fused inputs $\tilde{\xi}_{t+1}, \dots, \tilde{\xi}_{t+H}$:
$$ h_{t+h}, c_{t+h} = \mathrm{LSTM}_{\mathrm{dec}}(\tilde{\xi}_{t+h}, h_{t+h-1}, c_{t+h-1}) $$
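A minimal sketch of this encoder-decoder pair using standard `nn.LSTM` modules (sizes are hypothetical; in the real model both LSTMs operate on the fused inputs $\tilde{\xi}$ coming out of the variable selection networks):

```python
import torch
import torch.nn as nn

batch, k, H, d_model = 8, 48, 24, 32          # hypothetical sizes
past_xi = torch.randn(batch, k, d_model)      # fused past inputs xi_{t-k..t}
future_xi = torch.randn(batch, H, d_model)    # fused known future inputs xi_{t+1..t+H}

encoder = nn.LSTM(d_model, d_model, batch_first=True)
decoder = nn.LSTM(d_model, d_model, batch_first=True)

_, (h_t, c_t) = encoder(past_xi)              # (h_t, c_t) summarises the history
dec_out, _ = decoder(future_xi, (h_t, c_t))   # dec_out: (batch, H, d_model), one state per horizon
```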
5. Static Enrichment
This step combines the unchanging characteristics (like store location) with the time-varying information to create a richer understanding of each time step.
Before temporal fusion, each decoder hidden state $h_{t+h}$ is “enriched” by combining with the static context $S$:
$$ \hat{h}_{t+h} = \mathrm{GRN}_{\mathrm{enrich}}\bigl(\bigl[h_{t+h}, S\bigr]\bigr) + h_{t+h} $$
where the added $h_{t+h}$ term is a residual connection and layer normalization is applied inside the GRN.
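A minimal sketch of the enrichment step, with a plain MLP standing in for $\mathrm{GRN}_{\mathrm{enrich}}$ and randomly generated stand-ins for the decoder states and static context:

```python
import torch
import torch.nn as nn

batch, H, d_model = 8, 24, 32
h = torch.randn(batch, H, d_model)             # decoder hidden states h_{t+1..t+H}
S = torch.randn(batch, d_model)                # static context from Section 2

enrich = nn.Sequential(                        # stand-in for GRN_enrich acting on [h, S]
    nn.Linear(2 * d_model, d_model), nn.ELU(), nn.Linear(d_model, d_model)
)
norm = nn.LayerNorm(d_model)

S_rep = S.unsqueeze(1).expand(-1, H, -1)       # broadcast the static context over all horizons
h_hat = norm(h + enrich(torch.cat([h, S_rep], dim=-1)))   # residual connection + layer norm
```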
6. Temporal Fusion Decoder
This is where the model puts everything together to make its final predictions. It’s like a sophisticated decision-making process that considers all the information it has gathered.
The enriched decoder outputs then pass through a stack of Fusion Blocks, each consisting of:
6.1 Masked Interpretable Multi-Head Self-Attention
This is like having multiple experts looking at the data from different angles, but they can only see information up to the current time they’re predicting for.
- Queries, Keys, Values all from $\hat{H} = [\hat{h}_{t+1}, \dots, \hat{h}_{t+H}]$.
- We apply a causal mask $M$ (masking out the strict upper triangle) so that position $t+h$ cannot attend to any position later than $t+h$.
- Attention head formula:
$$ \mathrm{Attn}(Q,K,V) = \mathrm{softmax}\Bigl(\tfrac{QK^\top}{\sqrt{d_k}} + M\Bigr)V $$
- Unlike standard multi-head attention, the interpretable variant shares the value weights across all heads and averages the head outputs (rather than concatenating them) before the final projection, so the attention weights can be read as a single importance pattern over time steps.
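The sketch below shows the causal-masking idea using PyTorch's standard `nn.MultiheadAttention` (which concatenates heads rather than averaging them, so it is only a stand-in for TFT's interpretable variant); the boolean mask blocks attention to later positions:

```python
import torch
import torch.nn as nn

batch, H, d_model = 8, 24, 32
h_hat = torch.randn(batch, H, d_model)         # enriched decoder states

# causal mask: True marks positions that may NOT be attended to (the strict upper triangle)
causal_mask = torch.triu(torch.ones(H, H, dtype=torch.bool), diagonal=1)

attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
out, attn_weights = attn(h_hat, h_hat, h_hat, attn_mask=causal_mask)
# out: (batch, H, d_model); attn_weights: (batch, H, H), averaged over heads by default
```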
6.2 Gated Residual Network (GRN)
This is a smart way to process information that can skip unnecessary steps when they’re not needed, making the model more efficient.
Each sub-layer (attention or feed-forward) uses a GRN of the form:
- Input $a$ (and optional context $c$).
- Dense + ELU + Dense:
$$ z = W_2\bigl(\mathrm{ELU}(W_1[a,c] + b_1)\bigr) + b_2 $$
- Dropout on $z$.
- Gating:
$$ g = \sigma(W_g[a,c] + b_g),\quad \mathrm{GRN}(a,c) = \mathrm{LayerNorm}\bigl(a + g\odot z\bigr) $$
This lets the network skip transformations when not helpful.
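Following the equations above, a minimal GRN sketch might look like this (the optional context is padded with zeros when absent; note the original paper's GRN differs in some details, e.g. it uses a GLU-style gate on the transformed output):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedResidualNetwork(nn.Module):
    """Minimal sketch of the GRN equations above (optional context c)."""
    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(2 * d_model, d_model)    # W_1 acting on [a, c]
        self.fc2 = nn.Linear(d_model, d_model)        # W_2
        self.gate = nn.Linear(2 * d_model, d_model)   # W_g
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, a, c=None):
        if c is None:
            c = torch.zeros_like(a)                   # no context: pad with zeros
        ac = torch.cat([a, c], dim=-1)
        z = self.dropout(self.fc2(F.elu(self.fc1(ac))))    # Dense + ELU + Dense, then dropout
        g = torch.sigmoid(self.gate(ac))              # gating vector g
        return self.norm(a + g * z)                   # gated residual + layer norm

y = GatedResidualNetwork(d_model=32)(torch.randn(8, 32))   # shape preserved: (8, 32)
```

When the gate $g$ is close to zero the block reduces to a (normalized) identity mapping, which is what lets the network skip a transformation it does not need.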
6.3 Position-Wise Feed-Forward
This is the final processing step where each time point is processed independently. Think of it as a final check that makes specific adjustments for each future time point before making predictions.
After attention + add-norm, a point-wise FFN processes each time step independently:
$$ \mathrm{FFN}(x) = \mathrm{GRN}_{\mathrm{ff}}(x) $$
This step adds the final layer of processing that helps the model make more accurate predictions for each specific time point.
7. Quantile Forecasting Head
This is where the model makes its final predictions, but instead of giving a single number, it provides a range of possible outcomes with their probabilities. For example, instead of just saying “sales will be 100 units,” it might say “there’s a 90% chance sales will be between 80 and 120 units.”
Finally, for each desired quantile level $q \in \{q_1,\dots,q_L\}$, we apply a dense layer:
$$ \hat{y}_{t+h}(q) = W_q\, h'_{t+h} + b_q $$
where $h'_{t+h}$ is the output of the last fusion block at horizon $h$. The model is trained by minimizing the quantile loss:
$$ \mathcal{L}_q\bigl(y, \hat{y}(q)\bigr) = \max\bigl(q\,(y-\hat{y}(q)),\; (q-1)\,(y-\hat{y}(q))\bigr) $$
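A minimal sketch of this (pinball) loss for one quantile level, plus a hypothetical training objective summed over several levels:

```python
import torch

def quantile_loss(y, y_hat, q):
    """Pinball loss for a single quantile level q, following the formula above."""
    diff = y - y_hat
    return torch.maximum(q * diff, (q - 1) * diff).mean()

# hypothetical usage: targets and stand-in forecasts for H = 24 horizons
y = torch.randn(8, 24)
quantiles = [0.1, 0.5, 0.9]
forecasts = {q: torch.randn(8, 24) for q in quantiles}
total_loss = sum(quantile_loss(y, forecasts[q], q) for q in quantiles)
```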
Appendix: Jargon Definitions
- Transformer: A type of neural network architecture that’s particularly good at processing sequences of data
- Covariates: Input features or variables that might influence the outcome
- LSTM: Long Short-Term Memory, a type of neural network that’s good at remembering patterns over time
- GRN: Gated Residual Network, a component that can selectively process information
- Self-Attention: A mechanism that helps the model focus on the most relevant parts of the input data
References
Lim, B., Arik, S. O., Loeff, N., & Pfister, T. (2021).
Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting.
International Journal of Forecasting, 37(4), 1748-1764.
https://www.sciencedirect.com/science/article/pii/S0169207021000637