Flash Attention Algorithm Principles
1.1 Flash Attention Step-by-Step
1.2 Softmax
\( \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{k=1}^{N} e^{x_k}} \)
import torch

A = torch.randn(2, 6)
A_exp = torch.exp(A)
A_sum = torch.sum(A_exp, dim=1).unsqueeze(1)  # row-wise sum, shape (2, 1)
P = A_exp / A_sum  # broadcasting: each row is divided by its own sum
print(A)
print(P)
Output:
tensor([[ 1.0668, -0.3969, -0.2226,  0.7207,  1.0509, -1.0740],
        [ 0.6774,  1.0916, -1.8402, -1.0806,  0.9309,  2.4612]])
tensor([[0.3016, 0.0698, 0.0831, 0.2133, 0.2968, 0.0355],
        [0.0999, 0.1512, 0.0081, 0.0172, 0.1288, 0.5948]])
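As a side note (my own illustration, not in the original), the naive formula breaks down once the logits get large: \( e^{x} \) overflows float32 around \( x \approx 88 \), which is exactly what the Safe Softmax in the next subsection fixes.

import torch

B = torch.tensor([[1000.0, 1000.5, 999.0]])
print(torch.exp(B))                                           # inf: e^1000 overflows float32
print(torch.exp(B) / torch.exp(B).sum(dim=1, keepdim=True))   # nan, since inf / inf is undefined
print(torch.softmax(B, dim=1))                                # the built-in softmax stays finite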
1.3 Safe Softmax
The original softmax is numerically unstable: if some \( x_i \) is large, \( e^{x_i} \) overflows, and if all \( x_i \) are very negative, the denominator can underflow to 0. Rewriting it as the Safe Softmax version, which subtracts the maximum \( m \) so that every exponent is at most 0, avoids both problems:
\( \text{SafeSoftmax}(x_i) = \frac{e^{x_i - m}}{\sum_{k=1}^{N} e^{x_k - m}} \quad \text{where } m = \max(x_1, x_2, ..., x_N) \)
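A minimal sketch (my own check, not from the original text) of the safe version: subtract the per-row maximum before exponentiating, then verify that the result matches PyTorch's built-in torch.softmax.

import torch

A = torch.randn(2, 6)
m = A.max(dim=1, keepdim=True).values     # per-row maximum, shape (2, 1)
E = torch.exp(A - m)                      # every exponent is <= 0, so no overflow
P_safe = E / E.sum(dim=1, keepdim=True)
print(torch.allclose(P_safe, torch.softmax(A, dim=1)))  # True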
1.4 Online Softmax
Safe Softmax needs the global maximum before the denominator can be accumulated, which costs an extra pass over the data. Online Softmax removes this by maintaining a running maximum \( m_i = \max(x_1, ..., x_i) \) and a running denominator \( d_i' = \sum_{j=1}^{i} e^{x_j - m_i} \); whenever the maximum changes, the previously accumulated sum is rescaled:

\( \begin{aligned} d_i' &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \sum_{j=1}^{i-1} e^{x_j - m_i} + e^{x_i - m_i} \\ &= \sum_{j=1}^{i-1} e^{x_j - m_{i-1} + m_{i-1} - m_i} + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}' e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} \)
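The recurrence can be checked with a short one-pass loop (a sketch of my own; the names x, m_run, d_run are illustrative): keep the running maximum and running denominator, rescale the old denominator whenever the maximum grows, and compare the final \( d_N' \) with the two-pass Safe Softmax denominator.

import math
import torch

x = torch.randn(16)

m_run = float('-inf')   # m_0: running maximum
d_run = 0.0             # d_0': running denominator
for xi in x.tolist():
    m_new = max(m_run, xi)
    # d_i' = d_{i-1}' * e^{m_{i-1} - m_i} + e^{x_i - m_i}
    d_run = d_run * math.exp(m_run - m_new) + math.exp(xi - m_new)
    m_run = m_new

d_ref = torch.exp(x - x.max()).sum().item()   # two-pass Safe Softmax denominator
print(abs(d_run - d_ref) < 1e-4)              # True

This is the same rescaling trick that Flash Attention applies block by block when it accumulates the attention output in tiles.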