Forked from mblondel/projection_simplex_vectorized.py
Created
October 15, 2020 17:48
-
-
Save pandorica-opens/ff84b88699b370f505a68bd5282eff39 to your computer and use it in GitHub Desktop.
Vectorized projection onto the simplex
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Mathieu Blondel
# License: BSD 3 clause
| import numpy as np | |
def projection_simplex(V, z=1, axis=None):
    """
    Project V onto the simplex scaled by z.

    For a vector x the projection is

        P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2

    Parameters
    ----------
    V : array-like
        Values to project. Must be 2-D when axis is 0 or 1; with
        axis=None any shape is accepted because V is flattened first.
    z : float or array
        Target sum of each projected vector. If an array, it needs one
        entry per projected vector (per row for axis=1, per column for
        axis=0). z should be strictly positive: with z <= 0 no entry
        satisfies the support condition below and the threshold is
        computed with a division by zero.
    axis : None or int
        axis=None: project V by P(V.ravel(); z)
        axis=1: project each V[i] by P(V[i]; z[i])
        axis=0: project each V[:, j] by P(V[:, j]; z[j])
        Any other value is handled like axis=None.

    Returns
    -------
    ndarray
        Same shape as V for axis=0 and axis=1; a flat 1-D array for
        axis=None.
    """
    # Accept lists / nested lists as well as ndarrays; ndarrays pass
    # through without a copy.
    V = np.asarray(V)

    if axis == 1:
        n_features = V.shape[1]
        # Each row sorted in decreasing order.
        U = np.sort(V, axis=1)[:, ::-1]
        # Broadcast a scalar z to one target sum per row.
        z = np.ones(len(V)) * z
        cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
        ind = np.arange(n_features) + 1
        cond = U - cssv / ind > 0
        # Number of entries kept strictly positive in each row.
        rho = np.count_nonzero(cond, axis=1)
        # Per-row threshold subtracted from every entry before clipping.
        theta = cssv[np.arange(len(V)), rho - 1] / rho
        return np.maximum(V - theta[:, np.newaxis], 0)

    if axis == 0:
        # Columns of V are rows of V.T; project those and transpose back.
        return projection_simplex(V.T, z, axis=1).T

    # Treat the whole input as a single vector.
    V = V.ravel().reshape(1, -1)
    return projection_simplex(V, z, axis=1).ravel()
def _projection_simplex(v, z=1):
    """
    Reference (non-vectorized) simplex projection, kept for tests and
    benchmarks.

    v is expected to be a 1-D array and z a scalar.
    """
    n_items = v.shape[0]
    sorted_desc = np.sort(v)[::-1]
    shifted_cumsum = np.cumsum(sorted_desc) - z
    counts = np.arange(1, n_items + 1)
    in_support = sorted_desc - shifted_cumsum / counts > 0
    # Position of the last entry that satisfies the support condition.
    last = np.flatnonzero(in_support)[-1]
    rho = counts[last]
    theta = shifted_cumsum[last] / float(rho)
    return np.maximum(v - theta, 0)
def test():
    """
    Check the vectorized projection against the reference implementation.

    Covers axis=None (1-D and 2-D input), axis=1 and axis=0, each with a
    scalar z and, for the axis cases, with a vector z of ones.
    """
    # sklearn.utils.testing has been removed from scikit-learn; numpy
    # provides the same helper and is already a dependency of this file.
    from numpy.testing import assert_array_almost_equal

    rng = np.random.RandomState(0)
    V = rng.rand(100, 10)

    # Axis = None case.
    w = projection_simplex(V[0], z=1, axis=None)
    w2 = _projection_simplex(V[0], z=1)
    assert_array_almost_equal(w, w2)

    w = projection_simplex(V, z=1, axis=None)
    w2 = _projection_simplex(V.ravel(), z=1)
    assert_array_almost_equal(w, w2)

    # Axis = 1 case.
    W = projection_simplex(V, axis=1)

    # Check same as with for loop.
    W2 = np.array([_projection_simplex(row) for row in V])
    assert_array_almost_equal(W, W2)

    # Check works with vector z.
    W3 = projection_simplex(V, np.ones(V.shape[0]), axis=1)
    assert_array_almost_equal(W, W3)

    # Axis = 0 case.
    W = projection_simplex(V, axis=0)

    # Check same as with for loop.
    W2 = np.array([_projection_simplex(V[:, j]) for j in range(V.shape[1])]).T
    assert_array_almost_equal(W, W2)

    # Check works with vector z.
    W3 = projection_simplex(V, np.ones(V.shape[1]), axis=0)
    assert_array_almost_equal(W, W3)
def benchmark():
    """
    Time the vectorized projection against a per-row Python loop and plot
    the speedup (loop time / vectorized time) for several row counts.

    Both timings project the rows of the same matrix, so they measure
    the same work. Requires matplotlib for the final plot.
    """
    import time

    # Width of each projected vector (the value the original benchmark
    # actually used for its random matrices).
    n_features = 10
    n_repeats = 5
    sizes = (10, 100, 1000, 10000)

    rng = np.random.RandomState(0)

    vectorized = np.zeros(len(sizes))
    loop = np.zeros(len(sizes))

    for i, n_samples in enumerate(sizes):
        for _ in range(n_repeats):
            V = rng.rand(n_samples, n_features)

            # time.clock() was removed in Python 3.8; perf_counter() is
            # the documented replacement for interval timing.
            start = time.perf_counter()
            # axis=1 so that the rows are projected, matching the loop
            # below.
            projection_simplex(V, axis=1)
            vectorized[i] += time.perf_counter() - start

            start = time.perf_counter()
            # Distinct loop variable: must not shadow the outer index i.
            [_projection_simplex(row) for row in V]
            loop[i] += time.perf_counter() - start

        vectorized[i] /= n_repeats
        loop[i] /= n_repeats

    import matplotlib.pyplot as plt

    plt.figure()
    plt.plot(sizes, loop / vectorized, linewidth=3)
    plt.title("Vectorized projection onto the simplex")
    plt.xscale("log")
    plt.xlabel("Number of vectors to project")
    plt.ylabel("Speedup compared to using a for loop")
    plt.show()
# Script entry point: run the self-check first, then the timing benchmark.
if __name__ == '__main__':
    for entry_point in (test, benchmark):
        entry_point()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment