Skip to content

Instantly share code, notes, and snippets.

@Hironsan
Last active March 6, 2025 00:33
Show Gist options
  • Select an option

  • Save Hironsan/ea1a1df443b30d5428469f118facc3a5 to your computer and use it in GitHub Desktop.

Select an option

Save Hironsan/ea1a1df443b30d5428469f118facc3a5 to your computer and use it in GitHub Desktop.
Python implementation of the Infinite Relational Model (IRM).
import random
from collections import defaultdict
from collections.abc import Iterator
from typing import Literal, Optional
import networkx as nx
import numpy as np
import numpy.typing as npt
import tqdm
from scipy.special import betaln, logsumexp
from scipy.stats import mode
from sklearn.metrics import adjusted_rand_score
def chinese_restaurant_process(n: int, alpha: float, seed: Optional[int] = None) -> npt.NDArray:
    """Generate cluster assignments for ``n`` elements using the Chinese Restaurant Process.

    Element ``i`` joins an existing table with probability proportional to the
    number of elements already seated there, or opens a new table with
    probability proportional to ``alpha``. Tables are numbered 0, 1, 2, ... in
    order of creation.

    Args:
        n: Number of elements.
        alpha: Concentration parameter of the CRP; must be non-negative.
        seed: Seed passed to ``numpy.random.default_rng``.

    Returns:
        Integer array of length ``n`` (empty when ``n <= 0``) holding the table
        index of each element.

    Raises:
        ValueError: If ``alpha`` is negative.
    """
    if alpha < 0:
        raise ValueError(f"alpha must be non-negative, got {alpha}")
    rng = np.random.default_rng(seed)
    assignments: list[int] = []  # Table index of each element so far
    table_counts: list[int] = []  # Number of customers at each table
    for i in range(n):
        if i == 0:
            # The first customer always opens table 0; no random draw is needed.
            assignments.append(0)
            table_counts.append(1)
            continue
        # Weights for every existing table followed by the "new table" option.
        # Force a float dtype: with an integer alpha the array would be integral
        # and the in-place normalisation below would fail.
        probs = np.array([*table_counts, alpha], dtype=float)
        probs /= probs.sum()
        choice = rng.choice(len(probs), p=probs)
        if choice == len(table_counts):
            # Open a new table.
            table_counts.append(1)
            assignments.append(len(table_counts) - 1)
        else:
            # Join an existing table.
            table_counts[choice] += 1
            assignments.append(choice)
    return np.array(assignments, dtype=int)
class Object:
    """One entity to be clustered, identified by its domain and its position in it.

    Attributes:
        domain: 0 for the first domain (matrix rows), 1 for the second (matrix columns).
        index: Position of the entity within its domain.
    """

    def __init__(self, domain: Literal[0, 1], index: int) -> None:
        self.domain, self.index = domain, index

    def is_first_domain(self) -> bool:
        """Tell whether this object lives in domain 0."""
        first_domain = 0
        return self.domain == first_domain
class Objects:
    """All row and column objects of a matrix, iterated in shuffled order.

    Args:
        num_rows: Number of first-domain (row) objects to create.
        num_cols: Number of second-domain (column) objects to create.
        seed: Optional seed for a private ``random.Random`` used for shuffling.
            When ``None`` (the default) the global ``random`` module is used,
            exactly as before, so existing callers are unaffected.
    """

    def __init__(self, num_rows: int, num_cols: int, seed: Optional[int] = None) -> None:
        self.objects = [
            *(Object(0, i) for i in range(num_rows)),
            *(Object(1, j) for j in range(num_cols)),
        ]
        # A private generator makes the visit order reproducible without
        # touching the global random state; fall back to the module otherwise.
        self._random = random.Random(seed) if seed is not None else random

    def __iter__(self) -> Iterator[Object]:
        """Shuffle ``self.objects`` in place and return an iterator over it."""
        self._random.shuffle(self.objects)
        return iter(self.objects)
class SufficientStatistics:
    """Per-block counts of ones and zeros for a bi-clustered binary matrix.

    ``n_pos[(k, l)]`` / ``n_neg[(k, l)]`` hold the number of 1- / 0-entries of
    ``X`` whose row is in row cluster ``k`` and whose column is in column
    cluster ``l``. Both are ``defaultdict(int)``, so missing keys read as 0.
    """

    def __init__(self, z1: npt.NDArray, z2: npt.NDArray, X: npt.NDArray) -> None:
        """Build the counts from scratch.

        Args:
            z1: Row cluster label of each row of ``X``. Labels need not be
                contiguous (e.g. ``[0, 2]`` is accepted).
            z2: Column cluster label of each column of ``X``.
            X: Matrix of shape (len(z1), len(z2)); entries are expected to be 0/1.
        """
        # Map arbitrary labels to contiguous positions so the one-hot encoding
        # works even when some label ids are unused.
        row_ids, row_pos = np.unique(z1, return_inverse=True)
        col_ids, col_pos = np.unique(z2, return_inverse=True)
        z1_onehot = np.eye(len(row_ids))[row_pos.ravel()]
        z2_onehot = np.eye(len(col_ids))[col_pos.ravel()]
        n_pos = z1_onehot.T @ X @ z2_onehot
        n_neg = z1_onehot.T @ (1 - X) @ z2_onehot
        self.n_pos = defaultdict(int)
        self.n_neg = defaultdict(int)
        for a, k in enumerate(row_ids.tolist()):
            for b, l in enumerate(col_ids.tolist()):
                self.n_pos[(k, l)] = n_pos[a, b]
                self.n_neg[(k, l)] = n_neg[a, b]

    def update_row(
        self,
        X: npt.NDArray,
        z1: npt.NDArray,
        z2: npt.NDArray,
        i: int,
        increment: bool = True,
    ) -> None:
        """Add (or, with ``increment=False``, subtract) row ``i``'s entries to its blocks.

        The row's cluster is read from ``z1[i]`` and each column's from ``z2``.
        """
        delta = 1 if increment else -1
        k = z1[i]
        for j in range(X.shape[1]):
            l = z2[j]
            self.n_pos[(k, l)] += delta * X[i, j]
            self.n_neg[(k, l)] += delta * (1 - X[i, j])

    def update_col(
        self,
        X: npt.NDArray,
        z1: npt.NDArray,
        z2: npt.NDArray,
        j: int,
        increment: bool = True,
    ) -> None:
        """Add (or, with ``increment=False``, subtract) column ``j``'s entries to its blocks.

        The column's cluster is read from ``z2[j]`` and each row's from ``z1``.
        """
        delta = 1 if increment else -1
        l = z2[j]
        for i in range(X.shape[0]):
            k = z1[i]
            self.n_pos[(k, l)] += delta * X[i, j]
            self.n_neg[(k, l)] += delta * (1 - X[i, j])

    def remove_row_cluster(self, k: int, col_clusters: list[int]) -> None:
        """Delete the entries of row cluster ``k`` for every column cluster in ``col_clusters``.

        Raises:
            KeyError: If some ``(k, l)`` entry does not exist.
        """
        for l in col_clusters:
            del self.n_pos[(k, l)]
            del self.n_neg[(k, l)]

    def remove_col_cluster(self, l: int, row_clusters: list[int]) -> None:
        """Delete the entries of column cluster ``l`` for every row cluster in ``row_clusters``.

        Raises:
            KeyError: If some ``(k, l)`` entry does not exist.
        """
        for k in row_clusters:
            del self.n_pos[(k, l)]
            del self.n_neg[(k, l)]
class ClusterManager:
    """Track which member indices belong to each cluster id.

    Args:
        z: Cluster id of each member; member ``i`` is stored under ``z[i]``.
    """

    def __init__(self, z: npt.NDArray) -> None:
        self.clusters = defaultdict(list)
        for i, c in enumerate(z):
            self.clusters[c].append(i)

    def m(self, c: int) -> int:
        """Return the number of members of cluster ``c`` (0 if it does not exist)."""
        # ``.get`` avoids inserting an empty cluster as a side effect of a read.
        return len(self.clusters.get(c, ()))

    @property
    def new_cluster_id(self) -> int:
        """An id unused by current clusters: one past the largest id, or 0 when empty."""
        return max(self.clusters, default=-1) + 1

    @property
    def cluster_ids(self) -> list[int]:
        """Ids of the current clusters, in insertion order."""
        return list(self.clusters)

    @property
    def cluster_ids_with_new(self) -> list[int]:
        """Current cluster ids followed by ``new_cluster_id`` as the last element."""
        return [*self.clusters, self.new_cluster_id]

    def add_index(self, c: int, index: int) -> None:
        """Append member ``index`` to cluster ``c``, creating the cluster if needed."""
        self.clusters[c].append(index)

    def remove_cluster(self, c: int) -> None:
        """Drop cluster ``c`` entirely (raises KeyError if absent)."""
        self.clusters.pop(c)

    def remove_index(self, c: int, index: int) -> None:
        """Remove member ``index`` from cluster ``c`` (raises ValueError if absent)."""
        self.clusters[c].remove(index)

    def is_empty(self, c: int) -> bool:
        """Return True if cluster ``c`` has no members or does not exist."""
        return not self.clusters.get(c)
class InfiniteRelationalModel:
    """Infinite Relational Model for one binary relation between two domains.

    Rows (first domain) and columns (second domain) of a binary matrix ``X`` are
    clustered independently, each with a Chinese Restaurant Process prior. Every
    (row cluster, column cluster) block has a Bernoulli parameter with a
    Beta(a0, b0) prior that is integrated out. Inference is collapsed Gibbs
    sampling that re-assigns one object at a time.

    Args:
        alpha1: CRP concentration for the row (first-domain) partition.
        alpha2: CRP concentration for the column (second-domain) partition.
        a0: Beta prior pseudo-count paired with 1-entries.
        b0: Beta prior pseudo-count paired with 0-entries.
        num_iter: Number of Gibbs sweeps over all objects.
        burn_in: Number of leading sweeps discarded by ``fit_predict``.
        interval: Thinning step applied after burn-in by ``fit_predict``.
        seed: Seed controlling initialisation, visit order and sampling.
    """

    def __init__(
        self,
        alpha1: float,
        alpha2: float,
        a0: float,
        b0: float,
        num_iter: int,
        burn_in: int,
        interval: int,
        seed: Optional[int] = None,
    ) -> None:
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.a0 = a0
        self.b0 = b0
        self.num_iter = num_iter
        self.burn_in = burn_in
        self.interval = interval
        self.seed = seed
        # Assignment snapshots (one per sweep) from the most recent ``fit``.
        self.zs1: list[npt.NDArray] = []
        self.zs2: list[npt.NDArray] = []

    def fit(self, X: npt.NDArray) -> None:
        """Run collapsed Gibbs sampling, recording both partitions after every sweep.

        Samples from any previous ``fit`` call are discarded first, so repeated
        fits never mix chains.

        Args:
            X: 2-D matrix of shape (N1, N2) with 0/1 entries.
        """
        N1, N2 = X.shape
        self.zs1 = []
        self.zs2 = []
        # One parent generator hands out independent child seeds. Reusing
        # ``self.seed`` for every generator would make the two CRP
        # initialisations and the sampler consume identical random streams.
        rng = np.random.default_rng(self.seed)
        z1 = chinese_restaurant_process(N1, self.alpha1, int(rng.integers(2**32)))
        z2 = chinese_restaurant_process(N2, self.alpha2, int(rng.integers(2**32)))
        objects = Objects(N1, N2).objects
        stats = SufficientStatistics(z1, z2, X)
        row_cluster = ClusterManager(z1)
        col_cluster = ClusterManager(z2)
        for _ in tqdm.tqdm(range(self.num_iter)):
            # Draw the visit order from the seeded generator; iterating
            # ``Objects`` directly would shuffle with the global ``random``
            # module, which ``self.seed`` does not control.
            for position in rng.permutation(len(objects)):
                o = objects[position]
                if o.is_first_domain():
                    self._resample_row(X, z1, z2, o.index, stats, row_cluster, col_cluster, rng)
                else:
                    self._resample_col(X, z1, z2, o.index, stats, row_cluster, col_cluster, rng)
            self.zs1.append(z1.copy())
            self.zs2.append(z2.copy())

    def _resample_row(
        self,
        X: npt.NDArray,
        z1: npt.NDArray,
        z2: npt.NDArray,
        i: int,
        stats: SufficientStatistics,
        row_cluster: ClusterManager,
        col_cluster: ClusterManager,
        rng: np.random.Generator,
    ) -> None:
        """Gibbs-update ``z1[i]`` in place given all other assignments."""
        old = z1[i]
        # Take row i out of the statistics and its cluster before scoring.
        stats.update_row(X, z1, z2, i, increment=False)
        row_cluster.remove_index(old, i)
        if row_cluster.is_empty(old):
            row_cluster.remove_cluster(old)
            stats.remove_row_cluster(old, col_cluster.cluster_ids)
        probs = self.calculate_first_domain_posterior_prob(X, stats, row_cluster, col_cluster, i)
        k = rng.choice(row_cluster.cluster_ids_with_new, p=probs)
        z1[i] = k
        row_cluster.add_index(k, i)
        stats.update_row(X, z1, z2, i, increment=True)

    def _resample_col(
        self,
        X: npt.NDArray,
        z1: npt.NDArray,
        z2: npt.NDArray,
        j: int,
        stats: SufficientStatistics,
        row_cluster: ClusterManager,
        col_cluster: ClusterManager,
        rng: np.random.Generator,
    ) -> None:
        """Gibbs-update ``z2[j]`` in place given all other assignments."""
        old = z2[j]
        # Take column j out of the statistics and its cluster before scoring.
        stats.update_col(X, z1, z2, j, increment=False)
        col_cluster.remove_index(old, j)
        if col_cluster.is_empty(old):
            col_cluster.remove_cluster(old)
            stats.remove_col_cluster(old, row_cluster.cluster_ids)
        probs = self.calculate_second_domain_posterior_prob(X, stats, row_cluster, col_cluster, j)
        l = rng.choice(col_cluster.cluster_ids_with_new, p=probs)
        z2[j] = l
        col_cluster.add_index(l, j)
        stats.update_col(X, z1, z2, j, increment=True)

    def fit_predict(self, X: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
        """Fit the model and return a point estimate of the row and column partitions.

        Each object's estimate is its most frequent cluster id among the sweeps
        kept after discarding ``burn_in`` and thinning by ``interval``. The ids
        are then relabelled to 0..K-1. NOTE: ids are not aligned across samples
        before the mode is taken, so label switching between samples can distort
        the estimate.

        Raises:
            ValueError: If no sweeps remain after burn-in and thinning.
        """
        self.fit(X)
        kept1 = self.zs1[self.burn_in :: self.interval]
        kept2 = self.zs2[self.burn_in :: self.interval]
        if not kept1:
            raise ValueError(
                f"no samples kept: burn_in={self.burn_in} must be smaller than num_iter={self.num_iter}"
            )
        z1 = mode(kept1, keepdims=False).mode
        z2 = mode(kept2, keepdims=False).mode
        _, z1 = np.unique(z1, return_inverse=True)
        _, z2 = np.unique(z2, return_inverse=True)
        return z1, z2

    def calculate_first_domain_posterior_prob(
        self,
        X: npt.NDArray,
        stats: SufficientStatistics,
        row_cluster: ClusterManager,
        col_cluster: ClusterManager,
        i: int,
    ) -> npt.NDArray:
        """Return the conditional probability of each cluster choice for row ``i``.

        Row ``i`` must already be removed from ``stats`` and ``row_cluster``. The
        output is ordered like ``row_cluster.cluster_ids_with_new`` (existing
        clusters, then a fresh one).
        """
        candidates = row_cluster.cluster_ids_with_new
        new_k = row_cluster.new_cluster_id
        # Row i's ones and totals per column cluster do not depend on the
        # candidate k, so compute them once.
        row_counts = [
            (l, X[i, indices].sum(), len(indices)) for l, indices in col_cluster.clusters.items()
        ]
        log_probs = np.zeros(len(candidates))
        for idx, k in enumerate(candidates):
            # CRP prior: current cluster size, or alpha for a new cluster.
            log_probs[idx] = np.log(self.alpha1) if k == new_k else np.log(row_cluster.m(k))
            for l, pos, size in row_counts:
                a_hat = self.a0 + stats.n_pos[(k, l)]
                b_hat = self.b0 + stats.n_neg[(k, l)]
                # Beta-Bernoulli marginal likelihood ratio for adding row i's
                # entries to block (k, l).
                log_probs[idx] += betaln(a_hat + pos, b_hat + (size - pos)) - betaln(a_hat, b_hat)
        log_probs -= logsumexp(log_probs)
        return np.exp(log_probs)

    def calculate_second_domain_posterior_prob(
        self,
        X: npt.NDArray,
        stats: SufficientStatistics,
        row_cluster: ClusterManager,
        col_cluster: ClusterManager,
        j: int,
    ) -> npt.NDArray:
        """Return the conditional probability of each cluster choice for column ``j``.

        Column ``j`` must already be removed from ``stats`` and ``col_cluster``.
        The output is ordered like ``col_cluster.cluster_ids_with_new``.
        """
        candidates = col_cluster.cluster_ids_with_new
        new_l = col_cluster.new_cluster_id
        # Column j's ones and totals per row cluster do not depend on the
        # candidate l, so compute them once.
        col_counts = [
            (k, X[indices, j].sum(), len(indices)) for k, indices in row_cluster.clusters.items()
        ]
        log_probs = np.zeros(len(candidates))
        for idx, l in enumerate(candidates):
            # CRP prior: current cluster size, or alpha for a new cluster.
            log_probs[idx] = np.log(self.alpha2) if l == new_l else np.log(col_cluster.m(l))
            for k, pos, size in col_counts:
                a_hat = self.a0 + stats.n_pos[(k, l)]
                b_hat = self.b0 + stats.n_neg[(k, l)]
                log_probs[idx] += betaln(a_hat + pos, b_hat + (size - pos)) - betaln(a_hat, b_hat)
        log_probs -= logsumexp(log_probs)
        return np.exp(log_probs)
def load_dataset() -> tuple[npt.NDArray, npt.NDArray]:
    """Load Zachary's karate club as a binary adjacency matrix plus faction labels.

    Returns:
        X: int32 matrix with 1 wherever the graph has an edge, and 1 on the diagonal.
        Z: Label per node in graph node order: 0 for "Mr. Hi", 1 for "Officer".
    """
    karate = nx.karate_club_graph()
    adjacency = nx.to_numpy_array(karate)
    X = (adjacency > 0).astype(np.int32)
    # Every node is treated as linked to itself.
    np.fill_diagonal(X, 1)
    club_to_label = {"Mr. Hi": 0, "Officer": 1}
    labels = [club_to_label[attrs["club"]] for _, attrs in karate.nodes(data=True)]
    Z = np.array(labels)
    return X, Z
if __name__ == "__main__":
    # Fit the IRM on the karate club graph and compare both inferred partitions
    # with the ground-truth faction split.
    X, Z = load_dataset()
    model = InfiniteRelationalModel(
        alpha1=1.0,
        alpha2=1.0,
        a0=0.5,
        b0=0.5,
        num_iter=2000,
        burn_in=1500,
        interval=5,
        seed=41,
    )
    z1, z2 = model.fit_predict(X)
    # Keep the two scores separate; reusing one variable reported z2's score twice.
    ari_z1 = adjusted_rand_score(Z, z1)
    ari_z2 = adjusted_rand_score(Z, z2)
    print(f"ARI Score (Z vs z1): {ari_z1:.4f}")
    print(f"ARI Score (Z vs z2): {ari_z2:.4f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment