Created November 11, 2020 02:55
Gist o8r/b90fb65ff1c359564648964ef7ac885f
Fast shifting copy using NUMBA
```python
import numba


@numba.njit
def shifting_fill(src, dst):
    """Fill a 2D (or higher-dimensional) array by sliding a window over src, using NUMBA.

    This is useful when building the input of an RNN from time-series data.
    For example, if

        src = [0, 2, 1, 3, 5]

    and the number of time steps fed into the RNN is M = 3, then

        dst = [
            [0, 2, 1],
            [2, 1, 3],
            [1, 3, 5]
        ]

    Note that dst's shape must be (L, M, ...), where L <= len(src) - M + 1.

    Args:
        src (array): ndarray of shape (N, ...); 1D in the example above
        dst (array): ndarray of shape (L, M, ...) whose trailing axes match
            those of src; it is filled in place
    """
    assert dst.ndim >= 2
    # Read only the first two axes so that trailing feature axes are allowed
    # (unpacking dst.shape directly fails when dst.ndim > 2).
    L, M = dst.shape[0], dst.shape[1]
    # Every window src[i:i+M] with i < L must lie entirely inside src.
    assert L + M - 1 <= len(src)
    for i in range(L):
        # Row i of dst is the window that starts at time step i.
        dst[i] = src[i:i+M]
```
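A minimal usage sketch, assuming NumPy is installed and that Numba may or may not be. The helper `make_rnn_input` is hypothetical (not part of the gist): it allocates `dst` with the largest valid L, namely `len(src) - M + 1`, keeps any trailing feature axes of `src`, and then calls `shifting_fill`. The `try`/`except` fallback to plain Python is also an added assumption so the sketch runs without Numba; the function body follows the gist.

```python
import numpy as np

try:
    import numba
    jit = numba.njit
except ImportError:  # assumption: fall back to plain Python when Numba is absent
    def jit(func):
        return func


@jit
def shifting_fill(src, dst):
    # Same logic as the gist: row i of dst is the window src[i:i+M].
    assert dst.ndim >= 2
    L, M = dst.shape[0], dst.shape[1]
    assert L + M - 1 <= len(src)
    for i in range(L):
        dst[i] = src[i:i+M]


def make_rnn_input(src, M):
    """Hypothetical helper: every length-M window of src, shape (N - M + 1, M, ...)."""
    src = np.asarray(src)
    L = len(src) - M + 1  # largest L allowed by the assertion in shifting_fill
    dst = np.empty((L, M) + src.shape[1:], dtype=src.dtype)
    shifting_fill(src, dst)
    return dst


# 1D series from the docstring example
print(make_rnn_input(np.array([0, 2, 1, 3, 5]), 3).tolist())
# -> [[0, 2, 1], [2, 1, 3], [1, 3, 5]]

# Two features per time step: src has shape (5, 2), the result has shape (3, 3, 2)
features = np.arange(10.0).reshape(5, 2)
print(make_rnn_input(features, 3).shape)
# -> (3, 3, 2)
```

Because `dst` is supplied by the caller, the same buffer can be reused across calls instead of allocating a new array each time.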
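For cross-checking results, a swapped-in technique rather than the gist's method: NumPy 1.20+ provides `numpy.lib.stride_tricks.sliding_window_view`, which builds the same windows without an explicit loop. It appends the window axis last, so when `src` has trailing feature axes the window axis must be moved to position 1 to match the `(L, M, ...)` layout of `dst`. The name `shifting_windows` is hypothetical, and `.copy()` is used because the strided view is read-only by default.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def shifting_windows(src, M):
    """Hypothetical helper: windows of length M along axis 0, shape (N - M + 1, M, ...)."""
    src = np.asarray(src)
    # The window axis is appended last: shape (N - M + 1, ..., M)
    view = sliding_window_view(src, M, axis=0)
    # Move it to axis 1 and copy, since the strided view is read-only by default
    return np.moveaxis(view, -1, 1).copy()


print(shifting_windows([0, 2, 1, 3, 5], 3).tolist())
# -> [[0, 2, 1], [2, 1, 3], [1, 3, 5]]
```

Unlike the Numba loop, which writes into a caller-provided array, this version returns a fresh copy on every call.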