@o8r
Created November 11, 2020 02:55
Fast shifting copy using NUMBA
import numba


@numba.njit
def shifting_fill(src, dst):
    """Fill a 2D array by sliding a window over a 1D array, using Numba.

    This is useful when building the input to an RNN from time series data.
    For example, with
        src = [0, 2, 1, 3, 5]
    and M = 3 time steps fed to the RNN,
        dst = [
            [0, 2, 1],
            [2, 1, 3],
            [1, 3, 5]
        ]
    Note that dst's shape must be (L, M), where L <= len(src) - M + 1.

    Args:
        src (array): 1D ndarray
        dst (array): 2D ndarray of shape (L, M), filled in place
    """
    # The unpacking below only works for exactly two dimensions.
    assert dst.ndim == 2
    L, M = dst.shape
    # The last window starts at index L - 1 and must end inside src.
    assert L + M - 1 <= len(src)
    for i in range(L):
        # Row i is the length-M window of src starting at index i.
        dst[i] = src[i:i + M]
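
The windowing described in the docstring can be checked without Numba. The following is a minimal pure-Python sketch, not part of the gist: `shifting_fill_reference` and its parameters `num_rows` and `num_steps` are hypothetical names introduced here for illustration, and the function returns a new list of lists instead of filling `dst` in place.

```python
def shifting_fill_reference(src, num_rows, num_steps):
    """Pure-Python reference for the windowing done by shifting_fill.

    Hypothetical helper (not in the original gist): returns num_rows
    windows of length num_steps, where window i starts at src[i].
    """
    # Same bound as the assert in shifting_fill: L + M - 1 <= len(src).
    if num_rows + num_steps - 1 > len(src):
        raise ValueError("src is too short for the requested windows")
    return [list(src[i:i + num_steps]) for i in range(num_rows)]


# The example from the docstring: five samples, M = 3 time steps.
windows = shifting_fill_reference([0, 2, 1, 3, 5], num_rows=3, num_steps=3)
print(windows)  # [[0, 2, 1], [2, 1, 3], [1, 3, 5]]
```

With NumPy arrays, the corresponding call to the Numba version would allocate `dst` with shape `(3, 3)` first and then call `shifting_fill(src, dst)`, which writes the same rows into `dst`.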