Skip to content

Instantly share code, notes, and snippets.

@Anjum48
Created July 8, 2020 06:24
Show Gist options
  • Select an option

  • Save Anjum48/8984f657d6a8a884c91135cb4593a56e to your computer and use it in GitHub Desktop.

Select an option

Save Anjum48/8984f657d6a8a884c91135cb4593a56e to your computer and use it in GitHub Desktop.
Hard sample mining in PyTorch
class HardMiningBatchSampler(BatchSampler):
    """
    Creates batches that contain each class at most once and whose classes are
    chosen by embedding distance (hard class mining).
    Used for NPairLoss.

    Each batch is built by drawing an anchor class at random and then greedily
    adding the classes whose mean embedding is closest (cosine distance) to the
    anchor's, skipping classes whose samples have all been used this epoch.
    Every dataset index is yielded at most once per epoch.

    Args:
        labels: Class label of every dataset sample (list, NumPy array or CPU
            tensor). Yielded batch entries are positions into this sequence.
        batch_size: Maximum number of samples (= distinct classes) per batch.
        drop_last: If True, stop as soon as fewer than ``batch_size`` classes
            still have unused samples, so every yielded batch is full. If
            False, smaller batches are yielded until at most one class is left.
        classes: Size ``C`` of the random candidate pool the anchor class is
            drawn from. Defaults to ``len(labels) // 10``; values below 1 are
            treated as 1 when iterating.
    """

    def __init__(self, labels, batch_size, drop_last=True, classes=None):
        # BatchSampler.__init__ is deliberately not called: this sampler takes
        # no wrapped sampler and draws indices from ``labels`` directly.
        self.labels = labels
        self.batch_size = batch_size
        self.C = len(self.labels) // 10 if classes is None else classes
        self.labels_index = list(range(len(self.labels)))
        self.drop_last = drop_last
        # Normalise labels to hashable Python values; NumPy arrays and tensors
        # expose ``tolist``, plain sequences are copied as-is.
        label_list = labels.tolist() if hasattr(labels, "tolist") else list(labels)
        # One fixed class order shared by the centroids, ``self.neighbors`` and
        # ``__iter__``. First-appearance order avoids requiring sortable labels.
        self.class_labels = list(dict.fromkeys(label_list))
        position = {label: pos for pos, label in enumerate(self.class_labels)}
        # For every sample, the row/column of its class in ``self.neighbors``.
        self._class_positions = np.array(
            [position[label] for label in label_list], dtype=np.intp
        )
        self.neighbors = None
        # Random embeddings give a random class ordering until real ones arrive.
        self.update_distances(np.random.normal(size=(len(labels), 32)))

    def update_distances(self, embeddings):
        """
        Recompute the class-to-class neighbor ordering from sample embeddings.

        Args:
            embeddings: 2-D array-like with one row per sample, aligned with
                ``labels``. Rows are averaged per class and the class means
                are compared by cosine distance.

        Sets ``self.neighbors`` to an integer array of shape
        ``(n_classes, n_classes)``. Row ``i`` lists class positions (indices
        into ``self.class_labels``) from nearest to farthest from class ``i``,
        with ``i`` itself always first.

        Raises:
            ValueError: If the number of embedding rows differs from the
                number of labels.
        """
        embeddings = np.asarray(embeddings, dtype=float)
        n_samples = len(self._class_positions)
        if embeddings.shape[0] != n_samples:
            raise ValueError(
                f"expected {n_samples} embedding rows (one per label), "
                f"got {embeddings.shape[0]}"
            )
        n_classes = len(self.class_labels)
        sums = np.zeros((n_classes, embeddings.shape[1]))
        np.add.at(sums, self._class_positions, embeddings)
        counts = np.bincount(self._class_positions, minlength=n_classes)
        centroids = sums / counts[:, None]  # every class has >= 1 sample
        norms = np.linalg.norm(centroids, axis=1, keepdims=True)
        # A zero-length centroid gets a zero unit vector, i.e. cosine
        # distance 1 to every other class.
        unit = np.divide(
            centroids, norms, out=np.zeros_like(centroids), where=norms > 0
        )
        distances = 1.0 - unit @ unit.T
        # Make each class its own nearest neighbor even when another class has
        # an identical centroid.
        np.fill_diagonal(distances, -np.inf)
        self.neighbors = np.argsort(distances, axis=1, kind="stable")

    def __len__(self):
        """
        Nominal number of batches per epoch: ``len(labels) // batch_size``.

        NOTE: this is an estimate. A batch never repeats a class, so the
        number of batches actually yielded may be smaller (e.g. too few
        classes with ``drop_last=True``) or, with ``drop_last=False``, larger.
        """
        return len(self.labels) // self.batch_size

    def __iter__(self):
        """
        Yield one epoch of batches of dataset indices (positions in ``labels``).

        Each batch holds at most one sample per class. Samples of a class are
        consumed in their original order, and no index is yielded twice.
        """
        n_classes = len(self.class_labels)
        # Unused sample indices per class, stored reversed so ``pop()`` returns
        # the earliest remaining one in O(1).
        pools = [[] for _ in range(n_classes)]
        for sample_idx, pos in zip(self.labels_index, self._class_positions):
            pools[pos].append(sample_idx)
        for pool in pools:
            pool.reverse()
        # A candidate pool of size 0 (the default with fewer than 10 labels)
        # would make the anchor draw below fail.
        n_candidates = max(1, int(self.C))

        while True:
            available = [pos for pos, pool in enumerate(pools) if pool]
            # Stop when at most one class is left (with none left the draw
            # below would fail), or when a full batch is impossible and
            # drop_last is set.
            if len(available) <= 1 or (
                self.drop_last and len(available) < self.batch_size
            ):
                break
            # Choose randomly a large number of candidate classes C, then
            # select one of them as the anchor class.
            candidates = np.random.choice(available, size=n_candidates)
            anchor = int(np.random.choice(candidates))
            # Greedily add the classes nearest to the anchor that still have
            # unused samples; the anchor itself is first in its own row.
            batch_indexes = []
            for pos in self.neighbors[anchor]:
                pool = pools[pos]
                if not pool:
                    continue
                batch_indexes.append(pool.pop())
                if len(batch_indexes) == self.batch_size:
                    break
            yield batch_indexes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment