r"""
Optimized Einsum formula using opt_einsum

Requirements

* numpy
* https://github.com/fKunstner/quickbench/tree/v0.1.0
* https://github.com/dgasmith/opt_einsum/tree/v2.3.2
"""
|
|
|
import numpy as np |
|
import quickbench |
|
import opt_einsum as oe |
|
|
|
###
# Section banner: an empty line, the title, its underline, and another empty
# line, emitted with a single print call (one item per output line).
print(
    "",
    "Analysis of the starting formula",
    "--------------------------------",
    "",
    sep="\n",
)
###
|
|
|
# Contraction under study: four operands indexed bkl, bml, bkn, bmn that are
# reduced to a single (m, k) result.
einsum_string = 'bkl,bml,bkn,bmn->mk'

# Build random views to represent this contraction.
# One extent per index letter, listed in alphabetical order (b, k, l, m, n);
# this matches the shapes of the benchmark arrays, (16, 9, 676) for 'bkl' and
# (16, 32, 676) for 'bml'.
index_size = [16, 9, 676, 32, 676]

# Sort the letters before pairing them with index_size. Iterating a bare set
# of strings follows hash order, which changes from one interpreter run to the
# next (string hash randomisation), so zipping the unsorted set attached the
# extents to arbitrary letters and made the cost analysis irreproducible.
# NOTE: the recorded "Largest intermediate: 7.312e+6" further down equals
# 16 * 676 * 676, i.e. that sample was produced with such a scrambled pairing.
unique_inds = sorted(set(einsum_string) - {',', '-', '>'})
sizes_dict = dict(zip(unique_inds, index_size))
|
|
|
# Placeholder operands for the formula, sized according to sizes_dict.
views = oe.helpers.build_views(einsum_string, sizes_dict)

# Ask opt_einsum for a contraction order (`path`) together with the printable
# report that goes with it (`path_info`).
path, path_info = oe.contract_path(
    einsum_string,
    *views,
    optimize='optimal',
)

# Order of the simplification: each pair names the operands contracted at
# that step.
print(path)
#> [(0, 1), (0, 1), (0, 1)]

# Report for the chosen path. The lines below are a recorded sample run; the
# FLOP counts and intermediate sizes depend on the extents held in sizes_dict.
print(path_info)
#> Complete contraction: bkl,bml,bkn,bmn->mk
#> Naive scaling: 5
#> Optimized scaling: 4
#> Naive FLOP count: 8.423e+9
#> Optimized FLOP count: 6.142e+8
#> Theoretical speedup: 13.714
#> Largest intermediate: 7.312e+6 elements
#> --------------------------------------------------------------------------------
#> scaling   BLAS   current         remaining
#> --------------------------------------------------------------------------------
#> 4         0      bml,bkl->bmk    bkn,bmn,bmk->mk
#> 4         0      bmn,bkn->bmk    bmk,bmk->mk
#> 3         0      bmk,bmk->mk     mk->mk
|
###
# Section banner: an empty line, the title, its underline, and another empty
# line, emitted with a single print call (one item per output line).
print(
    "",
    "Testing the new formula",
    "-----------------------",
    "",
    sep="\n",
)
###

# Benchmark operands filled by np.random.randn: X has shape (16, 9, 676) and
# Y has shape (16, 32, 676). The tuple is evaluated left to right, so X is
# still drawn before Y.
X, Y = np.random.randn(16, 9, 676), np.random.randn(16, 32, 676)
|
|
|
|
|
def datafunc():
    """Return the benchmark inputs as the pair ``(X, Y)``.

    Takes no arguments and hands back the module-level arrays ``X`` and ``Y``
    themselves; no copy is made.
    """
    operands = (X, Y)
    return operands
|
|
|
|
|
def naive(X, Y):
    """Evaluate the module-level ``einsum_string`` in one ``np.einsum`` call.

    ``X`` is supplied as the first and third operand and ``Y`` as the second
    and fourth, one operand per term of the formula.
    """
    operands = (X, Y, X, Y)
    return np.einsum(einsum_string, *operands)
|
|
|
|
|
def optimized(X, Y):
    """Evaluate the contraction as three pairwise ``np.einsum`` steps.

    The steps are the ones listed in the path report printed earlier in the
    file: contract ``Y`` with ``X`` over ``l``, contract ``Y`` with ``X`` over
    ``n``, then multiply the two ``bmk`` intermediates and sum over ``b``.

    Returns the ``mk``-indexed result.
    """
    # Both intermediates are computed separately on purpose, mirroring the
    # reported path one step at a time.
    over_l = np.einsum('bml,bkl->bmk', Y, X)
    over_n = np.einsum('bmn,bkn->bmk', Y, X)
    return np.einsum('bmk,bmk->mk', over_l, over_n)
|
|
|
|
|
# Compare `optimized` against `naive` on the inputs produced by datafunc(),
# using np.allclose as the equality test (recorded output below).
quickbench.check(
    datafunc,
    [naive, optimized],
    compfunc=lambda x, y: np.allclose(x, y),
)
#> [1] optimized matches [0] naive: True

# Time both implementations on the same inputs (recorded sample output below).
quickbench.bench(
    datafunc,
    [naive, optimized],
)
#>+-----------------+-----------------+-----------------+
#>| Functions       | Time (tot)      | Time (per iter) |
#>+-----------------+-----------------+-----------------+
#>| [0] naive       | 0.047003s       | 0.004700s       |
#>| [1] optimized   | 0.033002s       | 0.003300s       |
#>+-----------------+-----------------+-----------------+
|
|
|
|
|
###
# Section banner: an empty line, the title, its underline, and another empty
# line, emitted with a single print call (one item per output line).
print(
    "",
    "Simplifying the resulting formula by hand",
    "-----------------------------------------",
    "",
    sep="\n",
)
###
|
|
|
|
|
def optimized2(X, Y):
    """Evaluate the contraction with a single ``np.einsum`` and a squared sum.

    The ``bmk`` intermediate (``Y`` contracted with ``X`` over ``l``) is
    computed once, squared element-wise, and summed over its first axis
    (``b``), giving the ``mk``-indexed result.
    """
    pairwise = np.einsum('bml,bkl->bmk', Y, X)
    squared = pairwise ** 2
    return squared.sum(axis=0)
|
|
|
|
|
# Compare both rewrites against `naive` on the inputs produced by datafunc(),
# using np.allclose as the equality test (recorded output below).
quickbench.check(
    datafunc,
    [naive, optimized, optimized2],
    compfunc=lambda x, y: np.allclose(x, y),
)
#> [1] optimized matches [0] naive: True
#> [2] optimized2 matches [0] naive: True

# Time all three implementations on the same inputs (recorded sample output
# below).
quickbench.bench(
    datafunc,
    [naive, optimized, optimized2],
)
#>+-----------------+-----------------+-----------------+
#>| Functions       | Time (tot)      | Time (per iter) |
#>+-----------------+-----------------+-----------------+
#>| [0] naive       | 0.047003s       | 0.004700s       |
#>| [1] optimized   | 0.033002s       | 0.003300s       |
#>| [2] optimized2  | 0.017001s       | 0.001700s       |
#>+-----------------+-----------------+-----------------+