Hard sample mining in PyTorch
import numpy as np
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import BatchSampler


class HardMiningBatchSampler(BatchSampler):
    """
    Creates batches that contain each class at most once, with classes chosen
    based on embedding distance.
    Used for NPairLoss.
    """

    def __init__(self, labels, batch_size, drop_last=True, classes=None):
        self.labels = labels
        self.batch_size = batch_size
        # Number of candidate classes C sampled when picking the seed class
        # self.C = batch_size * 2 if classes is None else classes
        self.C = len(self.labels) // 10 if classes is None else classes
        self.labels_index = list(range(len(self.labels)))
        self.drop_last = drop_last
        self.unique_classes = None
        self.neighbors = None
        # Start with random embeddings until real ones are supplied via update_distances()
        self.update_distances(np.random.normal(size=(len(labels), 32)))

    def update_distances(self, embeddings):
        """Recompute class-to-class nearest neighbours from per-sample embeddings."""
        labels_array = np.array(self.labels)
        # Fixed class ordering so neighbour indices stay valid throughout __iter__
        self.unique_classes = sorted(set(self.labels))
        embeddings_average = [embeddings[np.where(labels_array == x)].mean(axis=0) for x in self.unique_classes]
        neigh = NearestNeighbors(n_neighbors=len(embeddings_average), n_jobs=-1, metric="cosine")
        neigh.fit(embeddings_average)
        distances, neighbors = neigh.kneighbors(embeddings_average)
        self.neighbors = neighbors

    def __len__(self):
        return len(self.labels) // self.batch_size

    def __iter__(self):
        labels_temp = self.labels.copy()
        labels_index_temp = self.labels_index.copy()
        while True:
            unique_labels = list(set(labels_temp))
            if len(unique_labels) == 1 or (len(unique_labels) < self.batch_size and self.drop_last):
                break
            # Choose a large number of candidate output classes C at random
            random_classes = np.random.choice(unique_labels, size=self.C)
            # Select one class randomly from the C candidates, then greedily add the classes
            # that violate the triplet constraint the most (i.e. the nearest classes in
            # embedding space) until the batch contains batch_size classes.
            selected_class = np.random.choice(random_classes)
            idx = self.unique_classes.index(selected_class)
            nearest_classes = self.neighbors[idx]
            batch_labels = []
            batch_indexes = []
            for i in nearest_classes:
                try:
                    label = self.unique_classes[i]
                    # Take one remaining sample from this class
                    index = labels_temp.index(label)
                    batch_labels.append(label)
                    batch_indexes.append(labels_index_temp[index])
                    labels_temp.pop(index)
                    labels_index_temp.pop(index)
                except ValueError:
                    # All samples of this class have already been used; skip it
                    pass
                if len(batch_indexes) == self.batch_size:
                    break
            yield batch_indexes
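
Below is a minimal usage sketch, not part of the original gist: the toy TensorDataset, the nn.Linear stand-in for the embedding model, and all sizes are illustrative assumptions. It shows the sampler plugged into a DataLoader via the batch_sampler argument, with update_distances() called after each epoch so hard-class mining uses the latest embeddings.

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 1000 samples, 16 features, 50 classes (all values are illustrative)
num_samples, num_features, num_classes = 1000, 16, 50
features = torch.randn(num_samples, num_features)
labels = list(np.random.randint(0, num_classes, size=num_samples))
dataset = TensorDataset(features, torch.tensor(labels))

model = nn.Linear(num_features, 32)  # stand-in for a real embedding network
sampler = HardMiningBatchSampler(labels, batch_size=16)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(2):
    for x, y in loader:
        embeddings = model(x)  # an N-pair loss would be computed and backpropagated here
    # Refresh the class-distance estimates so the next epoch mines hard
    # (nearby) classes using the current embeddings
    with torch.no_grad():
        all_embeddings = model(features).numpy()
    sampler.update_distances(all_embeddings)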