Created
November 12, 2019 17:59
-
-
Save egafni/d89888aa19a9d75d7104e9cdfb556d82 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy | |
def duplicated(arr, keep='first'):
    """
    Mimic pandas.Series.duplicated, but works with multi-dimensional numpy arrays.

    Elements are compared along axis 0: for an input of shape ``(n, ...)`` each
    ``arr[i]`` (a scalar for 1-D input, a whole sub-array otherwise) is one
    element, and two elements are duplicates when all their entries are equal.

    Parameters
    ----------
    arr : array_like
        Input with at least one dimension.
    keep : {'first', 'last', False}, default 'first'
        - ``'first'``: mark duplicates as True except for the first occurrence.
        - ``'last'``: mark duplicates as True except for the last occurrence.
        - ``False``: mark every element that occurs more than once as True.

    Returns
    -------
    numpy.ndarray
        Boolean mask of length ``len(arr)``.

    Raises
    ------
    NotImplementedError
        If ``keep`` is not one of the supported values.
    """
    arr = numpy.asarray(arr)
    n = len(arr)
    if n == 0:
        # Nothing can be duplicated; avoid numpy.unique(axis=0) edge cases on empty input.
        return numpy.zeros(0, dtype=bool)

    if keep == 'first':
        # return_index yields the index of the FIRST occurrence of each unique element
        # (numpy uses a stable sort when return_index=True).
        _, first_idx = numpy.unique(arr, return_index=True, axis=0)
        mask = numpy.ones(n, dtype=bool)
        mask[first_idx] = False
    elif keep == 'last':
        # The first occurrence in the reversed array is the last occurrence in the
        # original; position j in arr[::-1] corresponds to n - 1 - j in arr.
        _, rev_first_idx = numpy.unique(arr[::-1], return_index=True, axis=0)
        mask = numpy.ones(n, dtype=bool)
        mask[n - 1 - rev_first_idx] = False
    elif keep is False:
        # Map each element to the count of its unique group; groups of size > 1 are duplicates.
        _, inverse, counts = numpy.unique(arr, return_inverse=True, return_counts=True, axis=0)
        # Flatten: NumPy 2.0.0 briefly returned a non-1-D inverse when axis is given.
        mask = counts[inverse.reshape(-1)] > 1
    else:
        raise NotImplementedError("keep must be 'first', 'last' or False, got %r" % (keep,))
    return mask
| import numpy | |
| import pandas | |
| import pytest | |
| from fbio.util.numpy_utils import duplicated | |
@pytest.mark.parametrize('arr', [[1, 2, 3, 4],
                                 [1, 2, 3, 1, 2],
                                 [[1, 2], [3, 4], [1, 2], [3, 5]],
                                 [[[1, 2],
                                   [3, 5]],
                                  [[1, 2],
                                   [4, 5]],
                                  [[1, 2],
                                   [4, 5]]],
                                 ])
def test_duplicated(arr):
    """Check ``duplicated(keep='first')`` against pandas as a reference implementation."""
    arr = numpy.asarray(arr)
    # Serialize each element along axis 0 to bytes so pandas can hash it, then let
    # pandas.Series.duplicated() supply the expected answer. All elements share one
    # dtype and shape here, so byte equality matches value equality for these ints.
    # tobytes() replaces ndarray.tostring(), which is deprecated since NumPy 1.19.
    pandas_res = pandas.Series([x.tobytes() for x in arr]).duplicated(keep='first').values
    our_res = duplicated(arr, keep='first')
    assert numpy.array_equal(pandas_res, our_res)
def test_duplicated_more_dimensions():
    """Duplicates are detected between whole 3-D sub-arrays of a 4-D input."""
    numpy.random.seed(1)
    data = numpy.random.rand(10, 3, 4, 6)
    # Independent random floats: no two sub-arrays along axis 0 coincide.
    assert not numpy.any(duplicated(data))
    # Overwrite sub-array 3 with a copy of sub-array 5 -> exactly one repeat.
    data[3] = data[5]
    flagged = duplicated(data)
    assert numpy.count_nonzero(flagged) == 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment