NMI (skimage): ~30 seconds, 95% accuracy.
NMI (numba)*: ~0.6 seconds, 95% accuracy.
@njit(cache=True)
def compute_joint(img1, img2):
    """Return the 2x2 joint histogram of two binary images.

    Entry ``[a, b]`` is the number of pixels where ``img1 == a`` and
    ``img2 == b``. Each pixel pair is encoded as the single integer
    ``2 * a + b`` so that one ``np.bincount`` pass counts all four
    combinations.

    Assumes both images hold only the values 0 and 1 and have the same
    number of pixels. Any larger value yields more than four bins, and the
    ``reshape(2, 2)`` then fails.
    """
    codes = 2 * img1.ravel() + img2.ravel()
    return np.bincount(codes, minlength=4).reshape(2, 2)


@njit(cache=True)
def entropy(counts, total):
    """Return the Shannon entropy, in bits, of a histogram.

    counts: 1-D array of non-negative bin counts.
    total:  divisor that turns counts into probabilities (normally the sum
            of ``counts``).
    """
    p = counts / total
    # Skip empty bins: by convention 0 * log2(0) == 0, but log2(0) is -inf.
    nonzero = p[p > 0]
    return -np.sum(nonzero * np.log2(nonzero))


@njit(cache=True)
def normalized_mutual_information(img1, img2):
    """Return the normalized mutual information of two binary images.

    Computes ``2 * I(A; B) / (H(A) + H(B))``. This is the arithmetic-mean
    normalization used by scikit-learn's ``normalized_mutual_info_score``,
    also known as symmetric uncertainty. It lies in [0, 1] up to rounding:
    0 when the images are independent, 1 when each determines the other.
    It is NOT the normalization of
    ``skimage.metrics.normalized_mutual_information``, which returns
    ``(H(A) + H(B)) / H(A, B)`` in [1, 2].

    When both images are constant, all entropies are zero and the ratio is
    0/0. In that case this returns 1.0, as scikit-learn does for two
    single-cluster labelings.
    """
    joint_counts = compute_joint(img1, img2)
    # Every pixel contributes exactly one count to the joint histogram, so
    # the total is the image size. (This was hard-coded to 28*28, which
    # gave wrong entropies for any other image size.)
    total_pixels = img1.size

    # Marginal counts: rows index img1's value, columns index img2's value.
    c1 = np.sum(joint_counts, axis=1)
    c2 = np.sum(joint_counts, axis=0)

    # Compute entropies directly from counts.
    h1 = entropy(c1, total_pixels)
    h2 = entropy(c2, total_pixels)
    joint_entropy = entropy(joint_counts.flatten(), total_pixels)

    mutual_info = h1 + h2 - joint_entropy
    denom = h1 + h2
    # Both images constant: 0/0 would raise ZeroDivisionError under numba's
    # default ('python') error model.
    if denom == 0.0:
        return 1.0
    return 2.0 * mutual_info / denom