Skip to content

Instantly share code, notes, and snippets.

@jgdshkovi
Created June 16, 2020 20:00
Show Gist options
  • Select an option

  • Save jgdshkovi/a1feb35c1ae069d9fc0e26ba75628f5d to your computer and use it in GitHub Desktop.

Select an option

Save jgdshkovi/a1feb35c1ae069d9fc0e26ba75628f5d to your computer and use it in GitHub Desktop.
from collections import defaultdict
def list_duplicates(seq):
    """Find items that occur more than once in ``seq``.

    The positions of every item are collected eagerly. A generator is then
    returned that yields ``(item, positions)`` only for items seen at least
    twice. ``positions`` is the ascending list of indices where the item
    appears.

    Items are yielded in order of first appearance. Items must be hashable.
    """
    positions = {}
    for index, value in enumerate(seq):
        positions.setdefault(value, []).append(index)
    return ((value, idxs) for value, idxs in positions.items() if len(idxs) > 1)
def return_mask(wts, labels, thr):
    """Build a keep/prune mask for filters that share a cluster label.

    For every label that occurs more than once, the first filter carrying it
    is the reference. Each later filter with the same label is compared to
    the reference with ``cos`` and marked for pruning when the similarity is
    strictly greater than ``thr``. Filters whose label is unique are always
    kept.

    Args:
        wts: indexable of per-filter weight arrays. Each entry is flattened
            to a ``(1, size)`` tensor via ``torch.from_numpy``, so numpy arrays
            are expected.
        labels: per-filter cluster labels (hashable), aligned with ``wts``.
        thr: similarity threshold above which a filter is pruned.

    Returns:
        list[int]: one entry per label; 1 = keep the filter, 0 = prune it.
    """
    def _as_row(idx):
        # Flatten one filter's weights into a single-row tensor.
        w = wts[idx]
        return torch.from_numpy(np.reshape(w, (1, w.size)))

    # Build the 0/1 mask directly instead of writing a -1 sentinel into a copy
    # of ``labels``: the sentinel would also have masked any genuine -1 label.
    mask = [1] * len(labels)
    # Each index belongs to exactly one duplicate group, so iteration order
    # does not affect the result (no need to sort the groups).
    for _, locs in list_duplicates(labels):
        reference = _as_row(locs[0])
        for i in locs[1:]:
            # NOTE(review): ``cos`` is not defined in this block; presumably a
            # module-level similarity callable such as nn.CosineSimilarity
            # returning a one-element tensor -- confirm.
            if cos(reference, _as_row(i)) > thr:
                mask[i] = 0
    return mask
def calc_distance(x1, y1, a, b, c):
    """Return the perpendicular distance from point (x1, y1) to a line.

    The line is given as ``a*x + b*y + c = 0``. The denominator uses
    ``math.hypot(a, b)`` rather than ``sqrt(a*a + b*b)``. Squaring can
    overflow to inf for huge coefficients, silently giving 0.0. It can also
    underflow to 0 for tiny ones, raising a spurious ZeroDivisionError.

    If ``a == b == 0`` the line is undefined. Plain Python numbers then raise
    ZeroDivisionError. Numpy scalars instead yield inf/nan with a
    RuntimeWarning.
    """
    return abs(a * x1 + b * y1 + c) / math.hypot(a, b)
# Results of the most recent optk() call, read by return_cluster_labels.
# Holds element [1] of each skm.spherical_k_means result (presumably the
# per-row labels -- confirm) in call order: k=1, k=2, ..., last k scanned,
# then k=shp.
ress = []


def optk(X, shp):
    """Pick a cluster count for ``X`` with an elbow (max-distance) heuristic.

    How the scan works:

    * ``skm.spherical_k_means`` is run for k = 1 and k = ``shp``. Element [2]
      of each result (presumably the inertia -- confirm) fixes the two ends
      of a reference line in (k, cost) space.
    * k = 2, 3, ... is then scanned, tracking the k whose point lies farthest
      from that line.
    * The scan stops once ``min(20 * shp // 100, 50)`` consecutive k values
      have failed to beat the best distance.
    * The miss counter starts at 1, so no interior k is tried when
      ``shp < 10``.

    Side effect: rebinds the module-level ``ress`` (see its comment).

    Args:
        X: data passed to ``skm.spherical_k_means``. In this file it is one
            row per conv filter.
        shp: largest k considered (the number of rows in this file's usage).

    Returns:
        int: the selected k. Returns 0 when no k in ``2..shp-1`` improved on
        the endpoints, which includes the case ``shp <= 1``.
    """
    global ress
    ress = []

    first = skm.spherical_k_means(X, n_clusters=1, random_state=10)
    if shp <= 1:
        # Both endpoints are the same clustering (a == b == 0), so the
        # reference line is undefined and calc_distance would divide by zero.
        ress.append(first[1])
        return 0
    last = skm.spherical_k_means(X, n_clusters=shp, random_state=10)

    cost_1, cost_n = first[2], last[2]
    # Line a*k + b*y + c = 0 through (1, cost_1) and (shp, cost_n).
    a = cost_1 - cost_n
    b = shp - 1
    c = cost_n - shp * cost_1

    ress.append(first[1])
    # Both lists are indexed by k; slot 0 is a placeholder so that a best_k
    # of 0 compares against a distance of 0.
    cost = [0, cost_1]
    line_dist = [0, calc_distance(1, cost_1, a, b, c)]

    patience = min(20 * shp // 100, 50)
    best_k = 0
    misses = 1  # NOTE(review): starts at 1, not 0 -- confirm this is intended.
    for k in range(2, shp):
        if misses >= patience:
            break
        result = skm.spherical_k_means(X, n_clusters=k, random_state=10)
        ress.append(result[1])
        cost.append(result[2])
        dist = calc_distance(k, cost[k], a, b, c)
        line_dist.append(dist)
        if dist > line_dist[best_k]:
            best_k, misses = k, 0
        else:
            misses += 1

    ress.append(last[1])
    return best_k
def return_cluster_labels(feat_wts, shp):
    """Cluster ``feat_wts`` with the k chosen by ``optk`` and return that labelling.

    Prints the k returned by ``optk``. Returns the matching entry of the
    module-level ``ress`` list that ``optk`` fills.
    """
    k = optk(feat_wts, shp)
    print(k)
    # optk() fills ``ress`` starting at k=1, so the k-cluster result lives at
    # ress[k - 1]. The previous ress[k] returned the k+1 clustering. A k of 0
    # (no interior elbow found) maps to the k=1 result at ress[0], as before.
    return ress[max(k, 1) - 1]
# Pruning pass over ``model``. For each Conv2d:
#   1. cluster its filters;
#   2. mark near-duplicate filters within each cluster;
#   3. record how many filters survive.
# cos_cfg:  surviving-filter count per Conv2d, with 'M' for each MaxPool2d.
# cfg_mask: one 0/1 tensor per Conv2d (1 = keep the filter).
cos_cfg = []
cfg_mask = []
layer_id = 0
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        shape = m.weight.data.shape
        print(shape)
        # One row per output filter. Tensor.numpy() only accepts CPU tensors,
        # so .cpu() lets this run on CUDA models too (it is a no-op on CPU).
        reshaped_tensor = (
            m.weight.data.cpu().clone().numpy()
            .reshape(shape[0], shape[1] * shape[2] * shape[3])
        )
        labels = return_cluster_labels(reshaped_tensor, shape[0])
        mask = return_mask(reshaped_tensor, labels, thr=0.20)
        kept = sum(mask)
        print(kept)
        cos_cfg.append(kept)
        cfg_mask.append(torch.tensor(mask))
        layer_id += 1
    elif isinstance(m, nn.MaxPool2d):
        layer_id += 1
        cos_cfg.append('M')
print(cos_cfg)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment