May 8-11, 2017 | Silicon Valley
Cris Cecka, Senior Research Scientist. May 11, 2017
LOW-COMMUNICATION FFT WITH FAST MULTIPOLE METHOD
THE FAST FOURIER TRANSFORM
Operation count: $4N \log_2 N - 6N + 8$
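The split-radix count above gives the real additions and multiplications for a power-of-two length N. A tiny helper evaluates it; the name `split_radix_flops` is ours, and it only encodes the formula (valid for N ≥ 2).

```python
def split_radix_flops(n: int) -> int:
    """Real additions plus multiplications of a length-n split-radix FFT:
    4 n log2(n) - 6 n + 8, for n a power of two with n >= 2."""
    if n < 2 or n & (n - 1):
        raise ValueError("n must be a power of two, n >= 2")
    log2n = n.bit_length() - 1  # exact integer log2 for powers of two
    return 4 * n * log2n - 6 * n + 8


for n in (8, 1024, 2**27):
    print(n, split_radix_flops(n))
```

For the N = 2^27 transforms in the parameter studies, this comes to about 1.4 × 10^10 operations.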
SPLIT-RADIX FFT: Algorithm
SPLIT-RADIX FFT: Profile
FMM-FFT: Edelman et al. 1999
STRUCTURED DENSE MATRICES AND FMM
• SVD: $A = U D V^*$
• Low-Rank: $K = U K_{r \times r} V^*$
• Hierarchically LR: $K_{IJ} = U_I K_{IJ} V_J^*$
• H-Semi-Separable:
• H2-Matrix/FMM: $K_{IJ} = U_I\, U_I\, K_{IJ}\, V_J^*\, V_J^*$
FMM-FFT: Algorithm
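$M_{M,P} = \mathrm{diag}(I_M, C_1, \ldots, C_{P-1})$
$[C_p]_{mn} = \rho_p \left[ \cot\left( \frac{\pi}{M} \left( n - m + \frac{p}{P} \right) \right) + \imath \right]$
The remaining steps form a 2D $M \times P$ FFT.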
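Applying each C_p densely costs O(M^2), or O(P M^2) over all blocks; this N-body-style sum is what the FMM replaces. A minimal dense reference for one block follows. ρ_p is passed in as an argument because its definition is not given on the slide, p = 0 is excluded because that block is I_M, and the name `apply_cp` is hypothetical.

```python
import math


def apply_cp(x, p, P, rho_p):
    """Dense O(M^2) product y = C_p x, where
    [C_p]_{mn} = rho_p * (cot(pi/M * (n - m + p/P)) + i), for 1 <= p < P."""
    if not 1 <= p < P:
        raise ValueError("the cot formula defines C_p only for 1 <= p < P")
    M = len(x)
    y = []
    for m in range(M):
        acc = 0j
        for n in range(M):
            # n - m + p/P is never an integer, so the cot argument is never 0 or +-pi
            theta = math.pi / M * (n - m + p / P)
            acc += rho_p * (1.0 / math.tan(theta) + 1j) * x[n]
        y.append(acc)
    return y


print(apply_cp([1.0, 0.0], p=1, P=2, rho_p=1.0))
```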
COT FMM
• One dimensional
• Uniform: integers are sources/targets
• Periodic
• Distributed
• Size M-by-M
• P of them!
• Interleaved
$[C_p]_{mn} = \rho_p \left[ \cot\left( \frac{\pi}{M} \left( n - m + \frac{p}{P} \right) \right) + \imath \right]$
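Because cot has period π, each entry of C_p depends only on (n − m) mod M, so every C_p is circulant. This is one concrete way to see the "uniform" and "periodic" properties. The check below sets ρ_p = 1 for illustration; the helper names are hypothetical.

```python
import math


def cp_matrix(M, p, P, rho_p=1.0):
    # [C_p]_{mn} = rho_p * (cot(pi/M * (n - m + p/P)) + i)
    return [[rho_p * (1.0 / math.tan(math.pi / M * (n - m + p / P)) + 1j)
             for n in range(M)] for m in range(M)]


def is_circulant(C, tol=1e-9):
    # Circulant: shifting both indices by one (mod M) leaves every entry unchanged.
    M = len(C)
    return all(abs(C[m][n] - C[(m + 1) % M][(n + 1) % M]) <= tol * (1 + abs(C[m][n]))
               for m in range(M) for n in range(M))


print(is_circulant(cp_matrix(8, 1, 4)), is_circulant(cp_matrix(8, 3, 4)))
```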
FMM OPERATORS
Each operator is an (implicit) matrix.
[Figure: FMM tree over levels B = 2, 3, L = 4 with leaf boxes of size M/2^L and rank-Q expansions, showing the S2M, M2M, M2L, L2L, L2T, and S2T operators.]
• S: “Source”
• T: “Target”
• M: “Multipole”
• L: “Local”
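In a 1D tree, the M2L interaction list of box b consists of the children of its parent's neighbours that are not adjacent to b. The sketch below is the standard 1D construction, assumed here rather than taken from the talk. Offsets s are relative, so periodic wrap-around only requires reducing b + s modulo the number of boxes. Each box gets three offsets, and four distinct offsets occur overall.

```python
def interaction_offsets(b):
    """Offsets s such that box b + s (same level) is in the M2L list of box b."""
    parent = b // 2
    offsets = []
    for q in (parent - 1, parent, parent + 1):   # parent and its two neighbours
        for child in (2 * q, 2 * q + 1):
            s = child - b
            if abs(s) > 1:                       # adjacent boxes: finer level or S2T
                offsets.append(s)
    return offsets


print(interaction_offsets(4), interaction_offsets(5))  # a left child, a right child
```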
PARAMETERS OF THE FMM-FFT
• FFT: $N = MP$
• FMM: rank $Q$, base level $B$, leaf box size $M_L$, leaf level $L = \log_2(M/M_L)$
Parameters: $(N, P, M_L, Q, B)$
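The derived quantities follow directly from the parameter tuple. This sketch uses the configuration from the parameter-study slides (N = 2^27, P = 256, ML = 64, Q = 16, B = 3). The class name and the validation rules (M/ML a power of two, 0 ≤ B ≤ L) are our assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FmmFftParams:
    N: int   # transform length, N = M * P
    P: int   # number of FMMs / second 2D-FFT dimension
    ML: int  # leaf box size
    Q: int   # rank (interpolation order)
    B: int   # base level

    def __post_init__(self):
        if self.N % self.P:
            raise ValueError("P must divide N")
        ratio, rem = divmod(self.N // self.P, self.ML)
        if rem or ratio < 1 or ratio & (ratio - 1):
            raise ValueError("M / ML must be a power of two")
        if not 0 <= self.B <= self.L:
            raise ValueError("need 0 <= B <= L")

    @property
    def M(self) -> int:
        return self.N // self.P

    @property
    def L(self) -> int:
        # leaf level L = log2(M / ML), exact for powers of two
        return (self.M // self.ML).bit_length() - 1


params = FmmFftParams(N=2**27, P=256, ML=64, Q=16, B=3)
print(params.M, params.L)
```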
DISTRIBUTED FMM
[Figure: communication pattern of the distributed FMM: two All2All gathers and halo exchanges labeled "Halo 1b" and "Halo 2b".]
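The figure does not specify the data layout, so the following halo-bookkeeping sketch makes two assumptions: each GPU owns a contiguous range of boxes at a level, and the box row is periodic, matching the periodic cot kernel. `halo_boxes` is a hypothetical helper. If "1b"/"2b" denote boxes, halo widths 1 and 2 correspond to those labels.

```python
def halo_boxes(first, count, width, nboxes):
    """Boxes a rank owning [first, first + count) must receive for a halo of
    `width` boxes on each side of a periodic row of `nboxes` boxes."""
    left = [(first - k) % nboxes for k in range(width, 0, -1)]
    right = [(first + count + k) % nboxes for k in range(width)]
    return left, right


# 8 boxes split over 2 GPUs: rank 0 owns boxes 0..3, rank 1 owns boxes 4..7.
print(halo_boxes(0, 4, 2, 8))
print(halo_boxes(4, 4, 1, 8))
```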
INTERPOLATIVE FMM
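• Same operators across all boxes
• Same operators across all levels
• Almost the same operators across all FMMs
Chebyshev nodes and Lagrange basis:
$z_j = \cos\left( \frac{(2j+1)\pi}{2Q} \right), \qquad \ell_i(z) = \prod_{0 \le k < Q,\ k \ne i} \frac{z - z_k}{z_i - z_k}$
Kernel approximation, with factors read right to left as S2M, M2M, M2L, L2L, L2T and summation over the repeated indices m, q, r, n:
$C_{ij} = \ell_m(t^I_i)\, \ell_q(z^I_m)\, C(z^I_q, z^J_r)\, \ell_r(z^J_n)\, \ell_n(s^J_j)$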
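The interpolation ingredients are short to write down. This pure-Python sketch (function names are ours) builds the Q Chebyshev nodes and the Lagrange basis, then checks that a polynomial of degree below Q is reproduced exactly. Every box uses the same reference nodes on [−1, 1], so these basis functions are shared by all boxes and levels.

```python
import math


def cheb_nodes(Q):
    # z_j = cos((2j + 1) pi / (2Q)), j = 0, ..., Q - 1
    return [math.cos((2 * j + 1) * math.pi / (2 * Q)) for j in range(Q)]


def lagrange(i, z, nodes):
    # l_i(z) = prod_{k != i} (z - z_k) / (z_i - z_k)
    out = 1.0
    for k, zk in enumerate(nodes):
        if k != i:
            out *= (z - zk) / (nodes[i] - zk)
    return out


def interpolate(f, z, nodes):
    # Degree-(Q-1) interpolant of f through the nodes, evaluated at z
    return sum(f(zq) * lagrange(q, z, nodes) for q, zq in enumerate(nodes))


nodes = cheb_nodes(4)
cubic = lambda x: x**3 - 2 * x + 1
print(interpolate(cubic, 0.3, nodes), cubic(0.3))
```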
TENSOR REPRESENTATIONS
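• Input: $S_n \equiv S_{pm} \equiv S_{pmb}$
• Output: $T_n \equiv T_{pm} \equiv T_{pmb}$
Strided tensor access: $A_{ijk\ell} := A[i + j \cdot \mathrm{ldA}\langle 1 \rangle + k \cdot \mathrm{ldA}\langle 2 \rangle + \ell \cdot \mathrm{ldA}\langle 3 \rangle]$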
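Under the strided definition, one flat buffer can be read as S_n, S_pm or S_pmb without copying, just by choosing strides. In the small sketch below, the class name and the p-fastest layout in the example are hypothetical, not the layout used in the talk.

```python
class StridedView:
    """Read-only view A_{i j k ...} := A[i*s0 + j*s1 + k*s2 + ...] over a flat buffer.
    With s0 = 1 this is the ldA<1>, ldA<2>, ldA<3> form on the slide."""

    def __init__(self, data, shape, strides):
        if len(shape) != len(strides):
            raise ValueError("shape and strides must have the same length")
        self.data, self.shape, self.strides = data, tuple(shape), tuple(strides)

    def __getitem__(self, idx):
        for i, n in zip(idx, self.shape):
            if not 0 <= i < n:
                raise IndexError(idx)
        return self.data[sum(i * s for i, s in zip(idx, self.strides))]


# Hypothetical dense layout for S_{pmb}: p fastest, then m, then b.
P, ML, nb = 2, 3, 4
flat = list(range(P * ML * nb))
S = StridedView(flat, (P, ML, nb), (1, P, P * ML))
print(S[1, 2, 3])
```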
S2M/L2T
$\mathrm{S2M}_{qm} = \ell_q(s_m), \qquad s_m = -1 + \frac{2m+1}{M_L}$
Computed with a single batched GEMM:
$M^L_{(p-1)qb} = \mathrm{S2M}_{qm} S_{pmb}$
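Here is a loop-level sketch of S2M. It builds S2M_qm = ℓ_q(s_m), with the leaf points s_m at the midpoints of M_L equal cells of [−1, 1], then contracts over m for every (p, b). We read the (p − 1) index as skipping p = 0, whose block in M_{M,P} is the identity and needs no FMM. That reading, the S[p][m][b] nesting, and all function names are assumptions. On the GPU the whole contraction is one batched GEMM; the loops here are only for clarity.

```python
import math


def cheb_nodes(Q):
    return [math.cos((2 * j + 1) * math.pi / (2 * Q)) for j in range(Q)]


def lagrange(i, z, nodes):
    out = 1.0
    for k, zk in enumerate(nodes):
        if k != i:
            out *= (z - zk) / (nodes[i] - zk)
    return out


def s2m_matrix(Q, ML):
    # S2M_{qm} = l_q(s_m), with s_m = -1 + (2m + 1) / ML
    z = cheb_nodes(Q)
    s = [-1 + (2 * m + 1) / ML for m in range(ML)]
    return [[lagrange(q, sm, z) for sm in s] for q in range(Q)]


def s2m_apply(S2M, S):
    # M[p-1][q][b] = sum_m S2M[q][m] * S[p][m][b], for p = 1, ..., P-1
    P, ML, nb = len(S), len(S[0]), len(S[0][0])
    return [[[sum(S2M[q][m] * S[p][m][b] for m in range(ML)) for b in range(nb)]
             for q in range(len(S2M))]
            for p in range(1, P)]


S2M = s2m_matrix(Q=4, ML=8)
ones = [[[1.0] * 2 for _ in range(8)] for _ in range(3)]   # P = 3, ML = 8, 2 boxes
Mout = s2m_apply(S2M, ones)
print(len(Mout), len(Mout[0]), len(Mout[0][0]))  # (P - 1, Q, boxes)
```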
BATCHED MATRIX-MATRIX MULTIPLY
cublas<T>gemmStridedBatched in cuBLAS 8.0
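cublas<T>gemmStridedBatched computes C_i = α op(A_i) op(B_i) + β C_i for i = 0, …, batchCount − 1, where the column-major operands of problem i start at i·strideA, i·strideB and i·strideC. The pure-Python reference below mirrors those semantics for the no-transpose case only, to make the index arithmetic concrete. It is an illustration, not a cuBLAS binding.

```python
def gemm_strided_batched_ref(m, n, k, alpha, A, lda, strideA,
                             B, ldb, strideB, beta, C, ldc, strideC, batch_count):
    """No-transpose reference: C_i = alpha * A_i @ B_i + beta * C_i (column-major)."""
    for i in range(batch_count):
        a0, b0, c0 = i * strideA, i * strideB, i * strideC
        for col in range(n):
            for row in range(m):
                acc = 0.0
                for t in range(k):
                    acc += A[a0 + row + t * lda] * B[b0 + t + col * ldb]
                idx = c0 + row + col * ldc
                # As in BLAS, C is not read when beta == 0.
                C[idx] = alpha * acc + (beta * C[idx] if beta != 0 else 0.0)
    return C


# Two 2x2 problems stored back to back (stride 4), column-major.
A = [1, 2, 3, 4, 5, 6, 7, 8]      # A_0 = [[1, 3], [2, 4]], A_1 = [[5, 7], [6, 8]]
B = [1, 0, 0, 1, 0, 1, 1, 0]      # B_0 = identity, B_1 swaps the two columns
C = [0.0] * 8
gemm_strided_batched_ref(2, 2, 2, 1.0, A, 2, 4, B, 2, 4, 0.0, C, 2, 4, 2)
print(C)
```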
S2M/L2T
$T_{pmb} = \mathrm{L2T}_{mq} L_{pqb} \implies T_{pm}[b] = L_{pq}[b]\, \mathrm{S2M}_{qm}$
$M_{pqb} = \mathrm{S2M}_{qm} S_{pmb} \implies M_{pq}[b] = S_{pm}[b]\, \mathrm{S2M}^T_{qm}$
M2M/L2L
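$\mathrm{M2M}^{\pm}_{qk} = \ell_q\left( \frac{z_k \pm 1}{2} \right)$
$M^\ell_{pqb} = \mathrm{M2M}_{qk}\, M^{\ell+1}_{pk(2b)}$
Computed with a single batched GEMM:
$L^{\ell+1}_{pq(2b)} = \mathrm{L2L}_{qk}\, L^\ell_{pkb} + L^{\ell+1}_{pq(2b)}$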
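The M2M operators are exact for the interpolative expansion. ℓ_q((x ± 1)/2) is itself a polynomial of degree Q − 1 in x, so interpolating it at the child nodes loses nothing. The sketch checks that a parent multipole computed directly from sources equals M2M applied to the child multipole. We assume the left child (2b) occupies [−1, 0] of its parent and uses M2M^−; names and the sample sources are illustrative.

```python
import math


def cheb_nodes(Q):
    return [math.cos((2 * j + 1) * math.pi / (2 * Q)) for j in range(Q)]


def lagrange(i, z, nodes):
    out = 1.0
    for k, zk in enumerate(nodes):
        if k != i:
            out *= (z - zk) / (nodes[i] - zk)
    return out


def m2m_matrices(Q):
    # M2M^{-/+}_{qk} = l_q((z_k -/+ 1) / 2): child node z_k in parent coordinates
    z = cheb_nodes(Q)
    minus = [[lagrange(q, (zk - 1) / 2, z) for zk in z] for q in range(Q)]
    plus = [[lagrange(q, (zk + 1) / 2, z) for zk in z] for q in range(Q)]
    return minus, plus


def multipole(points, weights, Q):
    # S2M for one box: M_q = sum_j l_q(s_j) w_j, points in the box's [-1, 1] frame
    z = cheb_nodes(Q)
    return [sum(lagrange(q, s, z) * w for s, w in zip(points, weights)) for q in range(Q)]


Q = 6
pts, wts = [-0.9, -0.2, 0.5], [1.0, -2.0, 0.5]   # sources in child coordinates
child = multipole(pts, wts, Q)
minus, plus = m2m_matrices(Q)
errors = {}
for sign, op in ((-1, minus), (1, plus)):
    direct = multipole([(x + sign) / 2 for x in pts], wts, Q)   # same sources, parent frame
    via_m2m = [sum(op[q][k] * child[k] for k in range(Q)) for q in range(Q)]
    errors[sign] = max(abs(d - v) for d, v in zip(direct, via_m2m))
print(errors)
```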
S2T/M2L
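• Also Level-3 linear algebra computations, but no BLAS primitives
• Custom kernels
$T_{pib} = \mathrm{S2T}_{p(j-i)} S_{pjb}, \qquad \mathrm{S2T}_{pk} = \begin{cases} \cot\left( \frac{\pi}{N}(p + Pk) \right) & p > 0 \\ \delta_{k0} & p = 0 \end{cases}$
$L^\ell_{pib} = \mathrm{M2L}^\ell_{pijs} M^\ell_{pj(b+s)}, \qquad \mathrm{M2L}^\ell_{pijs} = \cot\left( \frac{\pi}{2^\ell} \left( \frac{z_j}{2} - \frac{z_i}{2} + s \right) + \frac{\pi}{N}(p+1) \right)$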
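Both custom kernels evaluate the cot kernel from the FMM-FFT algorithm slide. Since π/M (k + p/P) = π/N (p + Pk) with N = MP, S2T_pk is the cot part of [C_p]_{m,m+k}. Likewise, M2L^ℓ is the same kernel at the physical separation (M/2^ℓ)(s + (z_j − z_i)/2) between node z_i of target box b and node z_j of source box b + s. The p + 1 is consistent with multipole arrays being indexed by p − 1. The check below uses illustrative values and hypothetical names.

```python
import math


def cot(x):
    return 1.0 / math.tan(x)


def s2t_kernel(p, k, N, P):
    # S2T_{pk} = cot(pi/N * (p + P*k)) for p > 0, and delta_{k0} for p = 0
    if p == 0:
        return 1.0 if k == 0 else 0.0
    return cot(math.pi / N * (p + P * k))


def m2l_kernel(level, p, zi, zj, s, N):
    # M2L^l_{pijs} = cot(pi/2^l * (z_j/2 - z_i/2 + s) + pi/N * (p + 1))
    return cot(math.pi / 2**level * (zj / 2 - zi / 2 + s) + math.pi / N * (p + 1))


M, P = 16, 4
N = M * P
# S2T versus the cot part of [C_p]_{m, m+k} for p = 3, k = 5
print(s2t_kernel(3, 5, N, P), cot(math.pi / M * (5 + 3 / P)))
# M2L at level 2 (box width M / 4) for FMM index p = 0, i.e. original p = 1
level, zi, zj, s = 2, 0.3, -0.6, 2
dist = M / 2**level * (s + (zj - zi) / 2)
print(m2l_kernel(level, 0, zi, zj, s, N), cot(math.pi / M * (dist + 1 / P)))
```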
INTERPOLATIVE FMM
Operator   Operator storage   Compute
S2M        Q ML               2PMQ
L2T        Q ML               2PMQ
S2T        P(4ML - 1)         3P 2^L ML^2
M2M        2Q^2               4(2^L - 2^B)PQ^2
L2L        2Q^2               4(2^L - 2^B)PQ^2
M2L        4(L - B)PQ^2       3(2^L - 2^B)PQ^2
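The compute column can be turned into a small cost model. This sketch assigns the counts to operators as in the table (our reading of the flattened slide) and evaluates them for the parameter-study configuration. The function name is hypothetical, and the counts are in whatever unit the slide uses.

```python
def fmm_compute_counts(N, P, ML, Q, B):
    """Per-operator compute counts from the interpolative-FMM table."""
    M = N // P
    L = (M // ML).bit_length() - 1          # leaf level, L = log2(M / ML)
    boxes = 2**L - 2**B                     # the (2^L - 2^B) factor in the table
    return {
        "S2M": 2 * P * M * Q,
        "L2T": 2 * P * M * Q,
        "S2T": 3 * P * 2**L * ML**2,
        "M2M": 4 * boxes * P * Q**2,
        "L2L": 4 * boxes * P * Q**2,
        "M2L": 3 * boxes * P * Q**2,
    }


counts = fmm_compute_counts(N=2**27, P=256, ML=64, Q=16, B=3)
for op, c in counts.items():
    print(f"{op}: {c:.3e}")
```

With these values S2T is the largest single term, which agrees with the FMM breakdown showing B-GEMM and S2T dominating.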
ALGORITHM
PROFILE
FMM-FFT PROFILE
[Figure: FMM-FFT profiler timeline with the S2M, M2M, halo exchange, S2T, M2L, L2L, and L2T stages, followed by the 2D FFT.]
2xK40c FMM-FFT
2xP100 FMM-FFT
8xP100 FMM-FFT
FMM BREAKDOWN
• T = ComplexDouble, A = 2×P100
• B-GEMM (batched GEMM) and S2T dominate
• Small N: latency dominates, so use 1 level
• Large N: compute dominates
[Figure: runtime breakdown by FMM component]
EFFICIENCY
• >95% batched GEMM
• 60% S2T/M2L
• >90% FMM-FFT
PARAMETER DEPENDENCE — ML
• Trade number of levels for S2T computation
• Flop count alone is not enough; increase the arithmetic intensity
• Tune performance for ML = 64
• T = Z (complex double), A = 2×P100, N = 2^27, P = 256, B = 3, Q = 16
[Figure x-axis: points per box per FMM]
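Using the S2T count from the cost table, 3P·2^L·ML^2 = 3N·ML, because P·2^L·ML = P·M = N. Doubling ML therefore removes one level but doubles the dense near-field work. Here is a minimal sketch of the trade at the slide's N = 2^27 and P = 256; the function name is hypothetical.

```python
def ml_tradeoff(N, P, ML):
    """Return (levels L, S2T count) for leaf box size ML."""
    M = N // P
    L = (M // ML).bit_length() - 1   # L = log2(M / ML): one level fewer per doubling of ML
    s2t = 3 * P * 2**L * ML**2       # equals 3 * N * ML whenever 2**L * ML == M
    return L, s2t


for ML in (16, 32, 64, 128, 256):
    L, s2t = ml_tradeoff(2**27, 256, ML)
    print(ML, L, s2t)
```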
PARAMETER DEPENDENCE — P
• Flops and intensity approximately constant
• Trade number of levels for number of FMMs
• Large P is good: fills up the B-GEMM and gives a more square 2D FFT
• T = Z, A = 2×P100, N = 2^27, ML = 64, B = 3, Q = 16
[Figure x-axis: number of FMMs]
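For fixed N and ML, each FMM has length M = N/P. Increasing P therefore shortens every FMM, removes levels, and moves the M × P 2D FFT toward a square shape. Some illustrative arithmetic with a hypothetical helper:

```python
def p_tradeoff(N, P, ML):
    M = N // P                          # length of each FMM, first 2D-FFT dimension
    L = (M // ML).bit_length() - 1      # levels per FMM, log2(M / ML)
    return {"M": M, "levels": L, "fft2d": (M, P)}


for P in (64, 256, 1024):
    print(P, p_tradeoff(2**27, P, 64))
```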
PARAMETER DEPENDENCE — B
• Not very significant
• Scale to 128 GPUs without complications
• T = Z, A = 2×P100, N = 2^27, P = 256, ML = 64, Q = 16
[Figure x-axis: base level]
PARAMETER DEPENDENCE — Q
• Weak performance dependence
• Accuracy tuning
• T = Z, A = 2×P100, N = 2^27, P = 256, ML = 64, B = 3
[Figure x-axis: quadrature order]
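Q mainly sets accuracy. For a smooth far-field kernel, the Chebyshev interpolation error decays geometrically in Q. The sketch measures the maximum interpolation error for an illustrative cot kernel whose nearest pole lies outside [−1, 1]. The kernel and all names are our choices, not taken from the talk.

```python
import math


def cheb_nodes(Q):
    return [math.cos((2 * j + 1) * math.pi / (2 * Q)) for j in range(Q)]


def lagrange(i, z, nodes):
    out = 1.0
    for k, zk in enumerate(nodes):
        if k != i:
            out *= (z - zk) / (nodes[i] - zk)
    return out


def max_interp_error(f, Q, samples=201):
    """Max error of the degree-(Q-1) Chebyshev interpolant of f on a uniform grid in [-1, 1]."""
    z = cheb_nodes(Q)
    fz = [f(x) for x in z]
    worst = 0.0
    for t in range(samples):
        x = -1 + 2 * t / (samples - 1)
        approx = sum(fz[q] * lagrange(q, x, z) for q in range(Q))
        worst = max(worst, abs(approx - f(x)))
    return worst


# Illustrative far-field kernel: cot(0.1 (x + 4)) has its nearest pole at x = -4.
def kernel(x):
    return 1.0 / math.tan(0.1 * (x + 4))


errors = {Q: max_interp_error(kernel, Q) for Q in (4, 8, 16)}
print(errors)
```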
FUTURE
• Integration into CUFFT
• Application to 2D/3D FFTs?
• Convolutions
• NUFFT, Sparse FFT
• Volta predictions and measurements
• Mixed precision (e.g. FP16 far-field) to use Tensor Cores?
• Persistent Matrix Batched GEMM (cuBLAS optimization)
• Staged Persistent Matrix Batched GEMM (cooperative groups, RNNs)
CONCLUSION
• FMM-FFT trades 2/3 of the communication in a 1D FFT for P FMMs
• Viable on the highest compute-to-communication-ratio architecture available
• Detailed implementation that relies heavily on existing primitives
• Primitives >95% efficient
• Two custom dense kernels >60% efficient
• Entire FMM-FFT >90% efficient
• Tunable accuracy-performance tradeoff
• Compute model accurately predicts performance
THANK YOU