216  Evaluating Embedding Models

Embedding models transform high-dimensional, discrete objects into continuous vector spaces where geometric relationships encode task-relevant similarity, whether that’s behavioral patterns for users, structural roles in networks, or attribute relationships for products. These representations power recommendation engines, fraud detection systems, customer segmentation, and predictive analytics across virtually every industry. Table 216.1 shows how the notion of similarity is understood across different domains.

Table 216.1: Similarity Terminologies Across Domains
| Domain | Similarity Term | What It Actually Captures |
|---|---|---|
| Words/Text | Semantic similarity | Meaning, synonymy, relatedness |
| Users | Behavioral similarity | Similar preferences, actions, consumption patterns |
| Products | Functional/attribute similarity | Similar features, use cases, purchase contexts |
| Network nodes | Structural similarity | Similar connectivity patterns, roles, neighborhood structure |
| Locations | Spatial/contextual similarity | Geographic proximity, similar visit patterns |

However, the gap between training an embedding model and deploying it in production is vast. A model that achieves impressive loss curves during training may fail catastrophically when confronted with real-world distribution shift, concept drift, or simply data that differs subtly from the training regime. This chapter provides a framework for evaluating, validating, and monitoring embedding models throughout their lifecycle.

We organize our treatment around three temporal phases:

  1. Pre-deployment evaluation: Intrinsic quality metrics, downstream task validation, and robustness testing before the model enters production
  2. Deployment validation: A/B testing, online metrics, and canary deployments that confirm the model performs as expected with real users
  3. Production monitoring: Continuous surveillance for drift, degradation, and anomalies that signal when intervention is required

Throughout, we use a running business example: a streaming media platform that embeds users, content items, and viewing sessions to power personalized recommendations. This temporal, streaming context mirrors the challenges faced in retail, financial services, healthcare, and any domain where data arrives continuously and patterns evolve.

216.1 Why Evaluation Matters in Business Contexts

Consider a retail e-commerce platform that deploys product embeddings to power “similar items” recommendations. Without rigorous evaluation:

  • Silent degradation: The model may slowly drift as product catalogs change, with no alert until revenue drops
  • Popularity bias: Embeddings may collapse to recommend only popular items, reducing catalog coverage and long-tail sales
  • Cold start failures: New products may receive poor embeddings, preventing discovery
  • Seasonal concept drift: Product relationships that held during training (winter coats similar to scarves) may not hold year-round

Proper evaluation catches these issues early, enabling proactive intervention rather than reactive firefighting.

216.2 Mathematical Preliminaries

Let \(\mathcal{X}\) denote our input space (e.g., the set of all users, products, or nodes in a network). An embedding function \(f: \mathcal{X} \rightarrow \mathbb{R}^d\) maps each entity \(x \in \mathcal{X}\) to a \(d\)-dimensional real vector. We denote the embedding of an entity \(x\) as \(\mathbf{e}_x = f(x)\).

The embedding matrix \(\mathbf{E} \in \mathbb{R}^{n \times d}\) contains embeddings for \(n\) entities, with each row \(\mathbf{E}_i\) representing entity \(i\).

For temporal embeddings, we index by time: \(\mathbf{E}^{(t)}\) represents the embedding matrix at time \(t\), and \(\mathbf{e}_x^{(t)}\) denotes the embedding of the entity \(x\) at time \(t\).

Common distance and similarity functions are summarized in Table 216.2.

Table 216.2: Distance Metrics
| Function | Formula | Range | Use Case |
|---|---|---|---|
| Euclidean distance | \(d(\mathbf{u}, \mathbf{v}) = \lVert\mathbf{u} - \mathbf{v}\rVert_2\) | \([0, \infty)\) | When magnitude matters |
| Cosine similarity | \(\text{sim}(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u} \cdot \mathbf{v}}{\lVert\mathbf{u}\rVert \lVert\mathbf{v}\rVert}\) | \([-1, 1]\) | Direction-only comparison |
| Dot product | \(\mathbf{u} \cdot \mathbf{v} = \sum_i u_i v_i\) | \((-\infty, \infty)\) | When both direction and magnitude matter |
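
The three functions in Table 216.2 can be sketched in a few lines of NumPy. This is a minimal illustration, not a production implementation; the function names and the two example vectors are hypothetical, chosen so that the vectors point in the same direction but differ in length.

```python
import numpy as np

def euclidean_distance(u: np.ndarray, v: np.ndarray) -> float:
    """L2 distance: sensitive to both direction and magnitude."""
    return float(np.linalg.norm(u - v))

def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine of the angle between u and v: ignores magnitude."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

def dot_product(u: np.ndarray, v: np.ndarray) -> float:
    """Unnormalized inner product: grows with both alignment and length."""
    return float(np.dot(u, v))

# Hypothetical embeddings: same direction, v is twice as long as u
u = np.array([1.0, 2.0, 2.0])
v = 2.0 * u

print(euclidean_distance(u, v))  # 3.0
print(cosine_similarity(u, v))   # 1.0
print(dot_product(u, v))         # 18.0
```

Cosine similarity treats the two vectors as identical, while Euclidean distance and the dot product both register the difference in magnitude.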
Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.spatial.distance import cdist, cosine
from scipy.linalg import svd
from sklearn.metrics import (
    roc_auc_score, average_precision_score, precision_recall_curve,
    ndcg_score, confusion_matrix, classification_report,
    silhouette_score, adjusted_rand_score, normalized_mutual_info_score,
    mean_squared_error, mean_absolute_error, r2_score,
    accuracy_score, f1_score
)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, TimeSeriesSplit
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from typing import Dict, List, Tuple, Optional, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

# Plotting configuration
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11
sns.set_style("whitegrid")

217 Intrinsic Evaluation: Embedding Quality

Before we deploy an embedding model to power recommendations, detect fraud, or segment customers, we want to know: Are these embeddings any good?

There are two fundamentally different ways to answer this question:

  1. Extrinsic evaluation: Test performance on a downstream task (e.g., does using these embeddings improve click-through rate? Can we predict which users will churn?)

  2. Intrinsic evaluation: Examine the geometric and statistical properties of the embeddings themselves, independent of any specific task

We first focus on intrinsic evaluation. Think of it as a health check for your embedding space: a way to diagnose problems with the embeddings before you invest time and money deploying them.

217.0.1 Why Bother with Intrinsic Metrics?

You might wonder: if we ultimately care about downstream performance, why examine embeddings in isolation?

Practical reasons:

  • Speed: Intrinsic metrics compute in seconds; downstream evaluation might require A/B tests running for weeks
  • Diagnosis: When downstream performance is poor, intrinsic metrics help identify why
  • Early warning: Catch problems during training, not after deployment
  • Comparison: Compare embedding methods before committing to expensive integration

Conceptual reason:

Embeddings are supposed to represent entities in a space where geometry encodes relationships. If the geometry itself is degenerate (e.g., all points clustered together, some dimensions unused, or certain points dominating nearest-neighbor queries), then no downstream task can fully recover.

Intrinsic evaluation asks: Is this embedding space geometrically healthy?

217.0.2 Isotropy: Are We Using the Whole Space?

Imagine you’re given a 100-dimensional space to represent 1 million users. You have 100 “degrees of freedom” to capture the diversity of user preferences, behaviors, and characteristics.

Now imagine that your embedding algorithm produces vectors where:

  • Dimension 1 has values ranging from -10 to +10
  • Dimension 2 has values ranging from -0.01 to +0.01
  • Dimensions 3-100 have values clustered tightly around zero

You’ve effectively wasted 99 of your 100 dimensions. Your “100-dimensional” embeddings are really just 1-dimensional, with noise in the other directions.

This is anisotropy: the embedding space is stretched in some directions and compressed in others, rather than using all directions equally.

Isotropic embeddings, by contrast, spread out evenly across all available dimensions, like a cloud of points forming a sphere rather than a cigar.


Let’s build intuition with a simple 2D example before moving to high dimensions.

Code
fig, axes = plt.subplots(1, 2)

n_points = 500

# Isotropic: points spread evenly in a circle
theta = np.random.uniform(0, 2 * np.pi, n_points)
r = np.random.uniform(0.5, 1.0, n_points)
iso_x = r * np.cos(theta)
iso_y = r * np.sin(theta)

ax = axes[0]
ax.scatter(iso_x, iso_y, alpha=0.5, s=20)
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect('equal')
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
ax.axvline(x=0, color='gray', linestyle='--', alpha=0.3)
ax.set_xlabel('Dimension 1')
ax.set_ylabel('Dimension 2')
ax.set_title('Isotropic Embeddings\n(Using both dimensions equally)')

# Anisotropic: points stretched along one direction
aniso_x = np.random.normal(0, 1.0, n_points)  # Wide spread
aniso_y = np.random.normal(0, 0.1, n_points)  # Narrow spread

ax = axes[1]
ax.scatter(aniso_x, aniso_y, alpha=0.5, s=20, color='coral')
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect('equal')
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
ax.axvline(x=0, color='gray', linestyle='--', alpha=0.3)
ax.set_xlabel('Dimension 1')
ax.set_ylabel('Dimension 2')
ax.set_title('Anisotropic Embeddings\n(Dimension 2 is essentially wasted)')

plt.tight_layout()
plt.show()

# Compute variance in each direction
print("Variance by dimension:")
print(f"  Isotropic:   Dim 1 = {np.var(iso_x):.3f}, Dim 2 = {np.var(iso_y):.3f}, Ratio = {np.var(iso_x)/np.var(iso_y):.1f}")
print(f"  Anisotropic: Dim 1 = {np.var(aniso_x):.3f}, Dim 2 = {np.var(aniso_y):.3f}, Ratio = {np.var(aniso_x)/np.var(aniso_y):.1f}")
Figure 217.1

In the isotropic case, both dimensions carry roughly equal variance (i.e., both are “doing work” to distinguish points). In the anisotropic case, dimension 1 carries 100× more variance than dimension 2. You could almost ignore dimension 2 entirely.

217.0.3 Why Does Anisotropy Happen?

Anisotropy isn’t random bad luck; it emerges systematically from how embeddings are trained:

  1. Frequency effects in language models

In Word2Vec-style models, common words get updated far more often than rare words. This pushes embeddings toward a “common direction” that all frequent words share. The result: all word vectors point roughly the same way, with small deviations encoding actual meaning.

  2. Popularity bias in recommendation systems

Popular items appear in many training examples. User embeddings get pulled toward popular items, and item embeddings get pulled toward the “average user.” The dominant direction becomes “popularity,” not “preference.”

  3. Optimization dynamics

Gradient descent often finds solutions that use only a subspace of available dimensions. If the loss function can be minimized using 10 dimensions, the optimizer has no incentive to spread information across all 100.

  4. Layer depth in neural networks

In deep networks (like BERT), anisotropy often increases with layer depth (Ethayarajh 2019). Early layers produce more isotropic representations; later layers collapse toward dominant directions.

217.0.4 The Consequence: Similarity Becomes Meaningless

Here’s why anisotropy matters for downstream applications:

When embeddings are anisotropic, cosine similarity loses discriminative power.

Code
fig, axes = plt.subplots(1, 2)

# Compute pairwise cosine similarities
def pairwise_cosine_similarities(points, n_pairs=5000):
    """Sample pairwise cosine similarities."""
    n = len(points)
    sims = []
    for _ in range(n_pairs):
        i, j = np.random.choice(n, 2, replace=False)
        a, b = points[i], points[j]
        cos_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)
        sims.append(cos_sim)
    return np.array(sims)

iso_points = np.column_stack([iso_x, iso_y])
aniso_points = np.column_stack([aniso_x, aniso_y])

iso_sims = pairwise_cosine_similarities(iso_points)
aniso_sims = pairwise_cosine_similarities(aniso_points)

ax = axes[0]
ax.hist(iso_sims, bins=50, density=True, alpha=0.7, color='steelblue', edgecolor='black')
ax.axvline(np.mean(iso_sims), color='red', linestyle='--', linewidth=2, 
           label=f'Mean = {np.mean(iso_sims):.3f}')
ax.set_xlabel('Cosine Similarity')
ax.set_ylabel('Density')
ax.set_title('Isotropic: Similarities Spread Out')
ax.set_xlim(-1, 1)
ax.legend()

ax = axes[1]
ax.hist(aniso_sims, bins=50, density=True, alpha=0.7, color='coral', edgecolor='black')
ax.axvline(np.mean(aniso_sims), color='red', linestyle='--', linewidth=2,
           label=f'Mean = {np.mean(aniso_sims):.3f}')
ax.set_xlabel('Cosine Similarity')
ax.set_ylabel('Density')
ax.set_title('Anisotropic: Similarities Pile Up at the Extremes')
ax.set_xlim(-1, 1)
ax.legend()

plt.tight_layout()
plt.show()

print(f"Isotropic:   Mean similarity = {np.mean(iso_sims):.3f}, Std = {np.std(iso_sims):.3f}")
print(f"Anisotropic: Mean similarity = {np.mean(aniso_sims):.3f}, Std = {np.std(aniso_sims):.3f}")
Figure 217.2

In the isotropic case, cosine similarities range widely from -1 to +1. We can meaningfully say “user A is very similar to user B (similarity 0.9) but quite different from user C (similarity 0.1).”

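In the anisotropic case, the similarities pile up at the extremes. Because these toy points are centered at the origin and stretched along one axis, almost every pair is either nearly parallel (similarity near +1) or nearly opposite (near -1), with little in between. Real embeddings usually also share a nonzero mean direction, which pushes nearly every pair toward the same high value (say, around 0.9): everyone looks like everyone else, and the embedding loses its ability to distinguish entities.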

217.0.5 Moving to High Dimensions: The Math

Now that we have intuition, let’s formalize. The core insight carries over directly to high dimensions.

217.0.5.1 Variance and Principal Components

For a statistician, anisotropy is most naturally understood through the eigenvalue decomposition of the covariance matrix.

Given embedding matrix \(\mathbf{E} \in \mathbb{R}^{n \times d}\) (n entities, d dimensions), center the data:

\[ \tilde{\mathbf{E}} = \mathbf{E} - \mathbf{1}\bar{\mathbf{e}}^\top \]

where \(\bar{\mathbf{e}} = \frac{1}{n}\sum_{i=1}^n \mathbf{e}_i\) is the mean embedding.

The sample covariance matrix is:

\[ \mathbf{\Sigma} = \frac{1}{n} \tilde{\mathbf{E}}^\top \tilde{\mathbf{E}} \]

Let \(\lambda_1 \geq \lambda_2 \geq \cdots \geq \lambda_d \geq 0\) be the eigenvalues of \(\mathbf{\Sigma}\).

Interpretation: \(\lambda_j\) is the variance of the data projected onto the \(j\)-th principal component. A large \(\lambda_1\) together with a small \(\lambda_d\) means most variance concentrates in a few directions (i.e., anisotropy).
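
The interpretation of \(\lambda_j\) as projected variance can be checked numerically. The sketch below uses a small synthetic matrix (the shape and per-dimension scales are assumptions for illustration) and the same \(1/n\) covariance convention as the formula above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical embeddings: 3 dimensions with standard deviations 3, 1, and 0.2
E = rng.normal(size=(2000, 3)) * np.array([3.0, 1.0, 0.2])

# Center and form the covariance matrix with the 1/n convention
E_centered = E - E.mean(axis=0)
Sigma = (E_centered.T @ E_centered) / len(E)

# eigh returns eigenvalues in ascending order; reorder to descending
eigvals, eigvecs = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]
eigvals, eigvecs = eigvals[order], eigvecs[:, order]

# Variance of the data projected onto each principal direction
# (np.var defaults to the 1/n convention, matching Sigma)
projected_var = (E_centered @ eigvecs).var(axis=0)

print("eigenvalues:       ", np.round(eigvals, 3))
print("projected variance:", np.round(projected_var, 3))
```

The two printed rows should agree up to floating-point error, which is exactly the statement that each eigenvalue is the variance along its principal component.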

217.0.5.2 Isotropy Metrics Based on Eigenvalues

217.0.5.2.1 Condition number (ratio of extremes):

\[ \kappa = \frac{\lambda_1}{\lambda_d} \]

  • \(\kappa = 1\): Perfect isotropy (all eigenvalues exactly equal, which never happens in practice)
  • \(\kappa > 1\): Some anisotropy exists (dominant directions)

Table 217.1: Condition Number Interpretation
| \(\kappa\) | Interpretation |
|---|---|
| 1 - 5 | Healthy, minor differences across dimensions |
| 5 - 20 | Mild anisotropy, probably fine |
| 20 - 100 | Moderate anisotropy, worth investigating |
| 100+ | Severe anisotropy, likely problematic |
| 1000+ | Extreme (some dimensions essentially unused) |

Real embeddings from well-trained models typically have \(\kappa\) in the 5-50 range. When you see \(\kappa > 100\), the largest eigenvalue is more than 100× the smallest (i.e., the weakest direction captures essentially no variance compared to the dominant one). See Table 217.1.

217.0.5.2.2 Participation ratio (effective dimensionality):

This metric, borrowed from physics, asks: “how many dimensions are actually contributing?”

\[ \text{PR} = \frac{\left(\sum_{j=1}^d \lambda_j\right)^2}{\sum_{j=1}^d \lambda_j^2} \]

  • If all eigenvalues are equal (\(\lambda_j = c\) for all \(j\)): \(\text{PR} = d\) (all dimensions contribute)
  • If one eigenvalue dominates (\(\lambda_1 \gg \lambda_{j>1}\)): \(\text{PR} \approx 1\) (effectively 1D)

Intuition: PR counts the “effective number of equal-sized eigenvalues” that would give the same ratio.
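
A small worked example makes the two limiting cases concrete. The eigenvalue lists below are made up for illustration; the helper names are not part of any library.

```python
import numpy as np

def participation_ratio(eigenvalues) -> float:
    """PR = (sum of eigenvalues)^2 / (sum of squared eigenvalues)."""
    lam = np.asarray(eigenvalues, dtype=float)
    return float(lam.sum() ** 2 / (lam ** 2).sum())

def condition_number(eigenvalues) -> float:
    """kappa = largest eigenvalue / smallest eigenvalue."""
    lam = np.asarray(eigenvalues, dtype=float)
    return float(lam.max() / lam.min())

flat = [1.0, 1.0, 1.0, 1.0]       # all directions carry equal variance
spiked = [10.0, 0.1, 0.1, 0.1]    # one direction dominates

for name, lam in [("flat", flat), ("spiked", spiked)]:
    print(f"{name:7s} PR = {participation_ratio(lam):.2f}, "
          f"kappa = {condition_number(lam):.1f}")
```

For the flat spectrum, PR equals the nominal dimension (4) and \(\kappa = 1\). For the spiked spectrum, PR is about 1.06 while \(\kappa\) is about 100: both metrics flag the same problem, but PR reports it in units of "dimensions actually used."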

Demonstrating eigenvalue-based isotropy metrics
def compute_eigenvalue_metrics(embeddings):
    """Compute isotropy metrics from eigenvalue decomposition."""
    # Center
    centered = embeddings - np.mean(embeddings, axis=0)
    
    # Covariance matrix
    n = len(embeddings)
    cov = (centered.T @ centered) / n
    
    # Eigenvalues
    eigenvalues = np.linalg.eigvalsh(cov)
    eigenvalues = np.sort(eigenvalues)[::-1]  # Descending
    eigenvalues = eigenvalues[eigenvalues > 1e-10]  # Remove numerical zeros
    
    # Metrics
    condition_number = eigenvalues[0] / eigenvalues[-1]
    
    participation_ratio = (np.sum(eigenvalues) ** 2) / np.sum(eigenvalues ** 2)
    
    # How many dimensions for 90% variance?
    cumvar = np.cumsum(eigenvalues) / np.sum(eigenvalues)
    dims_90 = np.searchsorted(cumvar, 0.9) + 1
    
    return {
        'eigenvalues': eigenvalues,
        'condition_number': condition_number,
        'participation_ratio': participation_ratio,
        'dims_for_90_variance': dims_90,
        'total_dims': len(eigenvalues)
    }


# Compare isotropic vs anisotropic in higher dimensions
d = 50  # 50 dimensions
n = 1000

# Isotropic: all dimensions have similar variance
iso_high = np.random.randn(n, d)  # Standard normal in all dimensions

# Anisotropic: variance decays across dimensions
decay = np.exp(-0.1 * np.arange(d))  # Exponential decay
aniso_high = np.random.randn(n, d) * decay

print("EIGENVALUE-BASED ISOTROPY ANALYSIS")
print("=" * 60)

iso_metrics = compute_eigenvalue_metrics(iso_high)
print(f"\nIsotropic embeddings ({d} dimensions):")
print(f"  Condition number (λ₁/λ_d): {iso_metrics['condition_number']:.1f}")
print(f"  Participation ratio: {iso_metrics['participation_ratio']:.1f} / {iso_metrics['total_dims']}")
print(f"  Dimensions for 90% variance: {iso_metrics['dims_for_90_variance']}")

aniso_metrics = compute_eigenvalue_metrics(aniso_high)
print(f"\nAnisotropic embeddings ({d} dimensions):")
print(f"  Condition number (λ₁/λ_d): {aniso_metrics['condition_number']:.1f}")
print(f"  Participation ratio: {aniso_metrics['participation_ratio']:.1f} / {aniso_metrics['total_dims']}")
print(f"  Dimensions for 90% variance: {aniso_metrics['dims_for_90_variance']}")
Code
fig, axes = plt.subplots(2, 1)

ax = axes[0]
ax.bar(range(len(iso_metrics['eigenvalues'])), iso_metrics['eigenvalues'], 
       alpha=0.7, color='steelblue')
ax.set_xlabel('Principal Component')
ax.set_ylabel('Eigenvalue (Variance)')
ax.set_title(f'Isotropic: Participation Ratio = {iso_metrics["participation_ratio"]:.1f}/{d}')
ax.set_xlim(-1, 50)

ax = axes[1]
ax.bar(range(len(aniso_metrics['eigenvalues'])), aniso_metrics['eigenvalues'],
       alpha=0.7, color='coral')
ax.set_xlabel('Principal Component')
ax.set_ylabel('Eigenvalue (Variance)')
ax.set_title(f'Anisotropic: Participation Ratio = {aniso_metrics["participation_ratio"]:.1f}/{d}')
ax.set_xlim(-1, 50)

plt.tight_layout()
plt.show()
Figure 217.3

In the top (isotropic) panel, the eigenvalues are nearly flat across all 50 principal components, ranging only from ~1.5 to ~0.6. Variance is distributed almost equally across all dimensions; no single component dominates.

The Participation Ratio (47.7/50): This metric quantifies “effective dimensionality.” A value of 47.7 out of 50 means the data effectively uses almost all available dimensions.

What this tells you: Isotropic data behaves like a spherical cloud in high-dimensional space. There’s no low-dimensional structure to exploit, so you can’t reduce dimensionality without losing substantial information. This is characteristic of pure noise or data where all features contribute independently and equally.

Contrast with anisotropic: The bottom panel shows eigenvalues dropping sharply, with PR = 10.2/50. That data has clear structure (a few dominant directions capture most of the variance), which makes dimensionality reduction effective.

Practical implication: If your actual data looks isotropic, PCA won’t help much for compression or feature extraction. If it looks anisotropic like the bottom panel, you can keep roughly the top 10 components and still retain most of the variance.

217.0.5.2.3 Average Pairwise Cosine Similarity (APCS)

An alternative approach: directly measure how similar random embedding pairs are.

For isotropic embeddings uniformly distributed on the unit hypersphere in \(\mathbb{R}^d\), theory tells us:

\[ \mathbb{E}[\cos(\mathbf{e}_i, \mathbf{e}_j)] = 0 \]

and

\[ \text{Var}[\cos(\mathbf{e}_i, \mathbf{e}_j)] = \frac{1}{d} \]

As dimensionality increases, random vectors become nearly orthogonal (the “blessing of dimensionality” for some applications, “curse” for others).1
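
The \(1/d\) variance result is easy to verify by simulation. The sketch below draws independent pairs of random unit vectors (normalized Gaussian samples, which are uniform on the sphere) for a few dimensions; the sample size and the list of dimensions are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_unit_vectors(n: int, d: int) -> np.ndarray:
    """Normalized Gaussian samples are uniformly distributed on the unit sphere."""
    x = rng.normal(size=(n, d))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

results = {}
for d in (2, 10, 100, 500):
    a = random_unit_vectors(5000, d)
    b = random_unit_vectors(5000, d)
    cos = np.sum(a * b, axis=1)          # cosine similarity of independent pairs
    results[d] = (float(cos.mean()), float(cos.var()))
    print(f"d={d:4d}  mean={cos.mean():+.4f}  var={cos.var():.5f}  1/d={1/d:.5f}")
```

The empirical mean stays near zero in every dimension, and the empirical variance tracks \(1/d\), so the typical magnitude of a random cosine similarity shrinks like \(1/\sqrt{d}\).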

In practice:

  • APCS \(\approx 0\): Healthy, isotropic embeddings
  • APCS \(> 0.3\): Moderate anisotropy, reduced discriminative power
  • APCS \(> 0.7\): Severe anisotropy, embeddings nearly useless for similarity
Computing and interpreting APCS
def compute_apcs(embeddings, n_pairs=10000):
    """
    Compute Average Pairwise Cosine Similarity.
    
    This directly measures how "spread out" embeddings are.
    """
    n, d = embeddings.shape
    
    # Normalize to unit length
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / (norms + 1e-10)
    
    # Sample random pairs
    sims = []
    for _ in range(n_pairs):
        i, j = np.random.choice(n, 2, replace=False)
        cos_sim = np.dot(normalized[i], normalized[j])
        sims.append(cos_sim)
    
    return {
        'mean': np.mean(sims),
        'std': np.std(sims),
        'median': np.median(sims),
        'percentile_5': np.percentile(sims, 5),
        'percentile_95': np.percentile(sims, 95)
    }


print("AVERAGE PAIRWISE COSINE SIMILARITY (APCS)")
print("=" * 60)

iso_apcs = compute_apcs(iso_high)
print(f"\nIsotropic embeddings:")
print(f"  APCS = {iso_apcs['mean']:.4f} ± {iso_apcs['std']:.4f}")
print(f"  Range (5th-95th percentile): [{iso_apcs['percentile_5']:.3f}, {iso_apcs['percentile_95']:.3f}]")
print(f"  Interpretation: Random pairs are nearly orthogonal ✓")

aniso_apcs = compute_apcs(aniso_high)
print(f"\nAnisotropic embeddings:")
print(f"  APCS = {aniso_apcs['mean']:.4f} ± {aniso_apcs['std']:.4f}")
print(f"  Range (5th-95th percentile): [{aniso_apcs['percentile_5']:.3f}, {aniso_apcs['percentile_95']:.3f}]")
print(f"  Interpretation: Mean stays near 0 for zero-centered data; the wider spread signals anisotropy")

217.0.6 Connecting Eigenvalues and APCS

These two perspectives (i.e., eigenvalue-based and similarity-based) are mathematically connected.

When the covariance matrix has a dominant eigenvector, all embeddings tend to align with that direction. Alignment means high cosine similarity between pairs.

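Rough relationship: If the first principal component explains fraction \(\rho\) of total variance, then for zero-centered embeddings the mean squared pairwise cosine similarity is roughly

\[ \mathbb{E}[\cos^2(\mathbf{e}_i, \mathbf{e}_j)] \approx \rho^2 + \text{(smaller terms)} \]

By symmetry, the mean itself (APCS) stays near zero for centered data; it rises when the embeddings also share a nonzero mean direction, as is common in trained models. So participation ratio dropping (i.e., variance concentrating) implies pairwise similarities moving away from zero (pairs becoming more alike, or more opposed).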
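
A quick numerical check helps separate two effects that are easy to conflate: variance concentrating in a few directions, and all vectors sharing a common offset. The sketch below is an illustration on synthetic Gaussian data, not a derivation; the sizes, the decay rate (chosen to mirror the anisotropic example above), and the offset of 3.0 are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def cosine_stats(E: np.ndarray, n_pairs: int = 20000):
    """Mean and mean-square cosine similarity over random distinct pairs."""
    unit = E / np.linalg.norm(E, axis=1, keepdims=True)
    i = rng.integers(0, len(E), size=n_pairs)
    j = rng.integers(0, len(E), size=n_pairs)
    keep = i != j                                  # drop self-pairs
    cos = np.sum(unit[i[keep]] * unit[j[keep]], axis=1)
    return float(cos.mean()), float((cos ** 2).mean())

def participation_ratio(E: np.ndarray) -> float:
    """PR from the eigenvalues of the centered covariance matrix."""
    centered = E - E.mean(axis=0)
    lam = np.linalg.eigvalsh(centered.T @ centered / len(E))
    return float(lam.sum() ** 2 / (lam ** 2).sum())

n, d = 2000, 50
iso = rng.normal(size=(n, d))                                   # zero mean, equal variances
aniso = rng.normal(size=(n, d)) * np.exp(-0.1 * np.arange(d))   # zero mean, decaying variances
shifted = iso + 3.0                                             # same spread, shared offset

summary = {}
for name, E in [("isotropic", iso), ("anisotropic", aniso), ("shifted", shifted)]:
    mean_cos, mean_sq = cosine_stats(E)
    summary[name] = (mean_cos, mean_sq, 1.0 / participation_ratio(E))
    print(f"{name:12s} APCS={mean_cos:+.3f}  mean cos^2={mean_sq:.3f}  "
          f"1/PR={summary[name][2]:.3f}")
```

In this setup, concentrating variance leaves the mean similarity near zero but inflates the mean squared similarity, which roughly follows \(1/\text{PR}\). Adding a shared offset pushes APCS toward 1 while the centered covariance, and therefore PR, barely changes. This suggests computing both families of metrics rather than relying on one.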

217.0.7 Practical Guidelines

Based on empirical studies across NLP, recommender systems, and graph embeddings:

Table 217.2: Isotropy Metrics
| Metric | Healthy | Concerning | Problematic |
|---|---|---|---|
| APCS | < 0.1 | 0.1 - 0.3 | > 0.3 |
| Participation Ratio / d | > 50% | 20% - 50% | < 20% |
| Condition Number | < 10 | 10 - 100 | > 100 |
| Dims for 90% Variance | > 30% of d | 10% - 30% | < 10% |

The thresholds in Table 217.2 aren’t hard cutoffs; context matters. But they provide useful diagnostic benchmarks.
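
The bands in Table 217.2 can be encoded as a small helper for automated checks. This is a sketch only: the function names and the example metric values are hypothetical, and the cutoffs should be tuned to your own domain.

```python
def grade(value: float, healthy: float, problematic: float,
          higher_is_better: bool) -> str:
    """Map a metric value onto the three bands of Table 217.2."""
    if higher_is_better:
        if value > healthy:
            return "healthy"
        return "problematic" if value < problematic else "concerning"
    if value < healthy:
        return "healthy"
    return "problematic" if value > problematic else "concerning"


def isotropy_report(apcs: float, pr_fraction: float,
                    condition_number: float, dims90_fraction: float) -> dict:
    """Grade each metric; fractions are expressed relative to the nominal dimension d."""
    return {
        "apcs": grade(apcs, 0.1, 0.3, higher_is_better=False),
        "participation_ratio": grade(pr_fraction, 0.5, 0.2, higher_is_better=True),
        "condition_number": grade(condition_number, 10, 100, higher_is_better=False),
        "dims_for_90_variance": grade(dims90_fraction, 0.3, 0.1, higher_is_better=True),
    }


# Hypothetical metric values for a 50-dimensional embedding
report = isotropy_report(apcs=0.05, pr_fraction=10.2 / 50,
                         condition_number=250.0, dims90_fraction=12 / 50)
for metric, verdict in report.items():
    print(f"{metric:22s} {verdict}")
```

Note that the metrics can disagree, as they do here: APCS looks healthy while the condition number is problematic. Disagreement is itself a useful signal that the spectrum deserves a closer look.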

217.0.8 What Causes Anisotropy? A Deeper Look

Understanding causes helps prevent anisotropy, not just detect it.

217.0.8.1 The Frequency-Weighted Mean Problem

Consider Skip-gram Word2Vec. Each word \(w\) with embedding \(\mathbf{e}_w\) gets updated when it appears in training. The update roughly pushes \(\mathbf{e}_w\) toward the average of its context words.

Frequent words get many more updates. They get pushed toward the frequency-weighted mean of all words. Over time, all embeddings converge toward this common direction.

Formally, if word \(w\) appears \(f_w\) times and contexts are uniformly sampled:

\[ \mathbf{e}_w \rightarrow \alpha \cdot \bar{\mathbf{e}} + (1-\alpha) \cdot \mathbf{e}_w^{\text{unique}} \]

where \(\bar{\mathbf{e}} = \frac{\sum_v f_v \mathbf{e}_v}{\sum_v f_v}\) is the frequency-weighted mean.

The \(\bar{\mathbf{e}}\) component is shared by all words; it’s the dominant direction causing anisotropy.
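
The effect of a shared component can be illustrated on synthetic data. In the sketch below, the Zipf-like frequencies, the mixing weight `alpha`, and the random `shared` direction standing in for the common component are all assumptions made for illustration; nothing here is estimated from a real Word2Vec model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical vocabulary: 1,000 words with Zipf-like frequencies f_w
n_words, d = 1000, 32
freq = 1.0 / np.arange(1, n_words + 1)

# Build embeddings as a mix of a shared direction and a word-specific part
alpha = 0.6                                  # illustrative mixing weight
shared = rng.normal(size=d)                  # stand-in for the common direction
unique = rng.normal(size=(n_words, d))       # word-specific components
E = alpha * shared + (1 - alpha) * unique

# Frequency-weighted mean: sum_v f_v e_v / sum_v f_v
e_bar = (freq[:, None] * E).sum(axis=0) / freq.sum()

def mean_cosine_to(E: np.ndarray, direction: np.ndarray) -> float:
    """Average cosine similarity between each row of E and a fixed direction."""
    cos = (E @ direction) / (np.linalg.norm(E, axis=1) * np.linalg.norm(direction))
    return float(cos.mean())

before = mean_cosine_to(E, e_bar)
after = mean_cosine_to(E - e_bar, e_bar)     # subtract the shared component
print(f"mean cosine to e_bar before: {before:.3f}")
print(f"mean cosine to e_bar after:  {after:.3f}")
```

Before subtraction, every word vector leans strongly toward \(\bar{\mathbf{e}}\); after subtraction, the alignment largely disappears, leaving the word-specific parts to carry the geometry.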

217.0.8.2 Implicit Regularization in Deep Networks

Neural networks trained with gradient descent exhibit implicit regularization: they find low-rank solutions even without explicit rank penalties.

For embedding layers, this means:

  • Early training: embeddings spread out to reduce loss quickly
  • Later training: embeddings collapse as the network finds “simpler” solutions that use fewer effective dimensions

This is why deeper layers in transformers often show higher anisotropy.

217.0.9 Correcting Anisotropy

Several approaches are available:

  1. Post-hoc whitening: Transform embeddings to have identity covariance

\[ \mathbf{E}_{\text{white}} = (\mathbf{E} - \mathbf{1}\bar{\mathbf{e}}^\top) \mathbf{\Sigma}^{-1/2} \]

This “spheres” the distribution, forcing isotropy. But it may distort semantically meaningful structure.

  2. Remove dominant directions: Project out the top-k principal components

\[ \tilde{\mathbf{e}} = \mathbf{e} - \sum_{j=1}^k (\mathbf{e} \cdot \mathbf{v}_j) \mathbf{v}_j \]

This removes the “common direction” while preserving relative differences. Often works better than full whitening.

  3. Kernel-Whitening

Standard ZCA/PCA whitening assumes linear correlations between dimensions. Kernel-Whitening addresses non-linear dependencies that linear whitening cannot remove.

  • How it works: It projects the embeddings into a higher-dimensional Reproducing Kernel Hilbert Space (RKHS) and performs whitening there. This is often approximated using the Nyström method to keep it computationally feasible.

  • Why use it: It is particularly effective if your embeddings contain complex, non-linear biases (e.g., stylistic biases in text) that linear removal (like All-but-the-top) fails to eliminate.

  4. Regularization Terms (loss-based correction)

Instead of changing the architecture (as methods such as IsoBN do), you can add a penalty term to your loss function during training that explicitly penalizes anisotropy.

  • Cosine Regularization: You add a term that minimizes the cosine similarity between random non-matching pairs in a batch.

    \(L_{total} = L_{task} + \lambda \sum_{i \neq j} \text{sim}(x_i, x_j)^2\)

    This forces the model to push unrelated vectors apart, expanding the “cone” they usually collapse into.

  • Whitening Penalty: You can directly penalize the difference between the embedding covariance matrix \(\Sigma\) and the identity matrix \(I\):

    \(L_{reg} = || \Sigma - I ||_F^2\) (Where \(||\cdot||_F\) is the Frobenius norm).

  5. Conceptor Negation (Soft Projection)

“All-but-the-top” removal is a “hard” projection: it completely deletes the top principal components. Conceptors offer a “soft” alternative that dampens dominant directions without removing them entirely.

  • How it works: A “conceptor” is a matrix that represents a subspace (like the direction of high anisotropy). Instead of subtracting this direction, you apply a logical “NOT” operation using matrix algebra.

    \(x_{new} = x_{old} (I - C)\)

    Where \(C\) is the conceptor matrix representing the common direction.

  • Why use it: Hard removal can accidentally delete useful information if the “noise” direction overlaps with actual meaning. Conceptors allow you to dial down the noise intensity rather than cutting it to zero.

  6. Layer-wise Adaptation (Last-Layer Normalization)

Anisotropy is most severe in the final layers of a model. Instead of post-processing the output, this method alters the architecture at the very end.

  • How it works: You replace the standard Layer Normalization in the final transformer block with a specialized normalization that enforces a unit sphere distribution before the final projection.

  • Why use it: It corrects the geometry at the source. Research shows that the bias parameters in the final LayerNorm are often the primary culprits for the “cone effect.” Setting these biases to zero or retraining strictly that layer can resolve the issue.
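
The two loss-based penalties described above (cosine regularization and the whitening penalty) can be sketched as plain NumPy functions. This only evaluates the penalty values on a batch; in real training they would be written in an autodiff framework and added to the task loss. Treating every pair in the batch as non-matching, the batch shapes, and the weight `lam` are assumptions for illustration.

```python
import numpy as np

def cosine_regularizer(X: np.ndarray) -> float:
    """Sum of squared cosine similarities over all pairs i != j in the batch."""
    unit = X / np.linalg.norm(X, axis=1, keepdims=True)
    sim = unit @ unit.T
    off_diag = sim - np.diag(np.diag(sim))     # zero out the i == j terms
    return float(np.sum(off_diag ** 2))

def whitening_penalty(X: np.ndarray) -> float:
    """Squared Frobenius norm of (batch covariance - identity)."""
    centered = X - X.mean(axis=0)
    cov = centered.T @ centered / len(X)
    return float(np.sum((cov - np.eye(X.shape[1])) ** 2))

rng = np.random.default_rng(0)
batch_iso = rng.normal(size=(256, 16))
# Collapsed batch: one dominant direction, the rest nearly unused
batch_collapsed = batch_iso * np.array([5.0] + [0.05] * 15)

lam = 0.01  # hypothetical regularization weight
for name, batch in [("isotropic", batch_iso), ("collapsed", batch_collapsed)]:
    print(f"{name:10s} cosine term = {lam * cosine_regularizer(batch):9.2f}  "
          f"whitening term = {whitening_penalty(batch):9.2f}")
```

Both penalties are much larger for the collapsed batch, so gradient descent on the combined loss is pushed away from that configuration. Note that the whitening penalty also constrains the overall scale of the embeddings, since it targets unit variance in every direction.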

Demonstrating anisotropy correction
def remove_principal_components(embeddings, n_remove=1):
    """Remove top principal components to reduce anisotropy."""
    # Center
    mean = np.mean(embeddings, axis=0)
    centered = embeddings - mean
    
    # SVD to find principal components
    U, S, Vt = np.linalg.svd(centered, full_matrices=False)
    
    # Remove top components
    result = centered.copy()
    for j in range(n_remove):
        component = Vt[j]  # j-th principal direction
        projections = centered @ component  # Project all points
        result = result - np.outer(projections, component)  # Subtract projection
    
    return result


# Apply correction to anisotropic embeddings
corrected = remove_principal_components(aniso_high, n_remove=3)

print("ANISOTROPY CORRECTION")
print("=" * 60)

print("\nBefore correction:")
print(f"  APCS = {aniso_apcs['mean']:.4f}")
print(f"  Participation Ratio = {aniso_metrics['participation_ratio']:.1f}")

corrected_apcs = compute_apcs(corrected)
corrected_metrics = compute_eigenvalue_metrics(corrected)
print("\nAfter removing top 3 principal components:")
print(f"  APCS = {corrected_apcs['mean']:.4f}")
print(f"  Participation Ratio = {corrected_metrics['participation_ratio']:.1f}")
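
For comparison with principal-component removal, here is a sketch of two other post-hoc transforms from the list above: full whitening and conceptor negation. The conceptor formula \(C = R(R + \alpha^{-2} I)^{-1}\), with \(R\) the uncentered correlation matrix, is the standard definition from the conceptor literature rather than something given in this chapter, and the aperture `alpha=2.0`, the synthetic data, and the helper names are assumptions.

```python
import numpy as np

def participation_ratio(E: np.ndarray) -> float:
    """Effective dimensionality from the centered covariance spectrum."""
    centered = E - E.mean(axis=0)
    lam = np.linalg.eigvalsh(centered.T @ centered / len(E))
    return float(lam.sum() ** 2 / (lam ** 2).sum())

def whiten(E: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Center, then multiply by Sigma^{-1/2} so the covariance becomes identity."""
    centered = E - E.mean(axis=0)
    cov = centered.T @ centered / len(E)
    eigvals, eigvecs = np.linalg.eigh(cov)
    inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return centered @ inv_sqrt

def conceptor_negation(E: np.ndarray, alpha: float = 2.0) -> np.ndarray:
    """Soft damping of dominant directions: x_new = x_old (I - C)."""
    d = E.shape[1]
    R = E.T @ E / len(E)                               # uncentered correlation matrix
    C = R @ np.linalg.inv(R + alpha ** -2 * np.eye(d)) # conceptor (assumed definition)
    return E @ (np.eye(d) - C)

rng = np.random.default_rng(0)
d = 50
aniso = rng.normal(size=(1000, d)) * np.exp(-0.1 * np.arange(d))

whitened = whiten(aniso)
softened = conceptor_negation(aniso)

print(f"PR original:  {participation_ratio(aniso):.1f} / {d}")
print(f"PR whitened:  {participation_ratio(whitened):.1f} / {d}")
print(f"PR conceptor: {participation_ratio(softened):.1f} / {d}")
```

Whitening drives the participation ratio to the full dimension by construction, at the cost of amplifying low-variance directions that may be mostly noise. Conceptor negation lands in between: it shrinks high-variance directions strongly and leaves weak ones almost untouched.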

217.0.10 Summary: The Isotropy Checklist

Before deploying embeddings, check:

  1. Compute APCS: Should be < 0.1 for healthy embeddings
  2. Examine eigenvalue spectrum: Should decay gradually, not precipitously
  3. Check participation ratio: Should be > 50% of nominal dimension
  4. If anisotropic: Consider removing top principal components before use

Anisotropic embeddings aren’t necessarily useless, but they have reduced representational capacity. You’re paying for \(d\) dimensions but only using a fraction of them.

217.0.11 Business Example: User Embeddings at a Streaming Platform

Consider a streaming platform that learns user embeddings from viewing history to power recommendations. If these embeddings are anisotropic, several problems emerge:

  1. Reduced personalization: If all user embeddings point in roughly the same direction, the system cannot distinguish between users with different tastes
  2. Popularity bias amplification: Anisotropic embeddings often emerge when popular content dominates training, pushing all users toward similar representations
  3. Cold start failures: New users get embeddings that look like everyone else, preventing differentiated recommendations
Code
def compute_isotropy_metrics(embeddings: np.ndarray, 
                              sample_size: int = 5000) -> Dict:
    """
    Compute comprehensive isotropy metrics for an embedding matrix.
    
    Parameters
    ----------
    embeddings : np.ndarray
        Embedding matrix of shape (n_entities, embedding_dim)
    sample_size : int
        Number of pairs to sample for APCS computation (for efficiency)
    
    Returns
    -------
    dict
        Dictionary containing isotropy metrics:
        - apcs: Average pairwise cosine similarity
        - apcs_std: Standard deviation of pairwise similarities
        - isotropy_score: Ratio of min to max eigenvalue
        - participation_ratio: Effective dimensionality
        - effective_dim_entropy: Entropy-based effective dimension
        - dim_90_variance: Dimensions needed for 90% variance
        - eigenvalues: Full eigenvalue spectrum
    """
    n, d = embeddings.shape
    
    # Normalize embeddings for cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1, norms)  # Avoid division by zero
    normalized = embeddings / norms
    
    # 1. Average Pairwise Cosine Similarity (APCS)
    # Sample for computational efficiency with large matrices
    if n > sample_size:
        idx = np.random.choice(n, sample_size, replace=False)
        sample = normalized[idx]
    else:
        sample = normalized
    
    # Compute pairwise cosine similarities via matrix multiplication
    sim_matrix = sample @ sample.T
    # Extract upper triangle (excluding diagonal)
    upper_tri_indices = np.triu_indices(len(sample), k=1)
    pairwise_sims = sim_matrix[upper_tri_indices]
    apcs = np.mean(pairwise_sims)
    apcs_std = np.std(pairwise_sims)
    
    # 2. Eigenvalue-based isotropy
    # Center the embeddings
    centered = embeddings - np.mean(embeddings, axis=0)
    # Compute covariance matrix
    cov_matrix = (centered.T @ centered) / n
    # Eigenvalue decomposition (returns sorted ascending)
    eigenvalues = np.linalg.eigvalsh(cov_matrix)
    eigenvalues = np.sort(eigenvalues)[::-1]  # Descending order
    
    # Filter out near-zero eigenvalues for numerical stability
    significant_eigenvalues = eigenvalues[eigenvalues > 1e-10]
    
    # Isotropy score (min/max eigenvalue ratio)
    if len(significant_eigenvalues) > 0:
        isotropy_score = significant_eigenvalues[-1] / significant_eigenvalues[0]
    else:
        isotropy_score = 0.0
    
    # 3. Effective dimensionality (participation ratio)
    # PR = (sum of eigenvalues)^2 / sum of eigenvalues^2
    eigenvalues_positive = eigenvalues[eigenvalues > 0]
    if len(eigenvalues_positive) > 0:
        participation_ratio = (np.sum(eigenvalues_positive) ** 2) / np.sum(eigenvalues_positive ** 2)
    else:
        participation_ratio = 0.0
    
    # 4. Entropy-based effective dimensionality
    # Normalize eigenvalues to form a probability distribution
    if np.sum(eigenvalues_positive) > 0:
        eigenvalues_norm = eigenvalues_positive / np.sum(eigenvalues_positive)
        entropy = -np.sum(eigenvalues_norm * np.log(eigenvalues_norm + 1e-10))
        effective_dim_entropy = np.exp(entropy)
    else:
        effective_dim_entropy = 0.0
    
    # 5. 90% variance dimensionality
    if np.sum(eigenvalues_positive) > 0:
        cumsum = np.cumsum(eigenvalues_positive) / np.sum(eigenvalues_positive)
        dim_90_variance = np.searchsorted(cumsum, 0.90) + 1
    else:
        dim_90_variance = d
    
    return {
        'apcs': apcs,
        'apcs_std': apcs_std,
        'isotropy_score': isotropy_score,
        'participation_ratio': participation_ratio,
        'effective_dim_entropy': effective_dim_entropy,
        'dim_90_variance': dim_90_variance,
        'nominal_dim': d,
        'eigenvalues': eigenvalues
    }


def diagnose_isotropy(metrics: Dict) -> str:
    """
    Provide diagnostic interpretation of isotropy metrics.
    
    Returns human-readable diagnosis with recommendations.
    """
    diagnosis = []
    
    # APCS interpretation
    if metrics['apcs'] < 0.1:
        diagnosis.append("✓ APCS indicates good isotropy (embeddings well-distributed)")
    elif metrics['apcs'] < 0.3:
        diagnosis.append("⚠ APCS suggests moderate anisotropy (some directional clustering)")
    else:
        diagnosis.append("✗ APCS indicates severe anisotropy (embeddings collapsed to narrow cone)")
    
    # Effective dimensionality
    dim_utilization = metrics['participation_ratio'] / metrics['nominal_dim']
    if dim_utilization > 0.5:
        diagnosis.append(f"✓ Good dimension utilization ({dim_utilization:.1%} effective)")
    elif dim_utilization > 0.2:
        diagnosis.append(f"⚠ Moderate dimension utilization ({dim_utilization:.1%} effective)")
    else:
        diagnosis.append(f"✗ Poor dimension utilization ({dim_utilization:.1%} effective)")
    
    # 90% variance dimensionality
    var_ratio = metrics['dim_90_variance'] / metrics['nominal_dim']
    if var_ratio > 0.3:
        diagnosis.append(f"✓ Variance spread across dimensions ({metrics['dim_90_variance']}/{metrics['nominal_dim']} dims for 90% variance)")
    else:
        diagnosis.append(f"✗ Variance concentrated ({metrics['dim_90_variance']}/{metrics['nominal_dim']} dims capture 90% variance)")
    
    return "\n".join(diagnosis)
Demonstrate isotropy metrics with synthetic data
# Generate isotropic embeddings (uniform on hypersphere)
def generate_isotropic_embeddings(n: int, d: int) -> np.ndarray:
    """Generate embeddings uniformly distributed on unit hypersphere."""
    embeddings = np.random.randn(n, d)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / norms

# Generate anisotropic embeddings (clustered in narrow cone)
def generate_anisotropic_embeddings(n: int, d: int, 
                                     concentration: float = 0.1) -> np.ndarray:
    """
    Generate anisotropic embeddings clustered around a mean direction.
    
    Parameters
    ----------
    concentration : float
        Lower values = more concentrated (more anisotropic)
    """
    # Mean direction (dominant first dimension)
    mean_dir = np.zeros(d)
    mean_dir[0] = 1.0
    
    # Add noise with varying scale per dimension
    scales = np.array([1.0] + [concentration] * (d - 1))
    embeddings = mean_dir + np.random.randn(n, d) * scales
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / norms

# Compare isotropic vs anisotropic
n_entities = 10000
embedding_dim = 128

np.random.seed(42)
isotropic_emb = generate_isotropic_embeddings(n_entities, embedding_dim)
anisotropic_emb = generate_anisotropic_embeddings(n_entities, embedding_dim, 
                                                   concentration=0.1)

iso_metrics = compute_isotropy_metrics(isotropic_emb)
aniso_metrics = compute_isotropy_metrics(anisotropic_emb)

print("=" * 60)
print("ISOTROPIC EMBEDDINGS")
print("=" * 60)
print(f"APCS: {iso_metrics['apcs']:.4f} ± {iso_metrics['apcs_std']:.4f}")
print(f"Isotropy Score (λ_min/λ_max): {iso_metrics['isotropy_score']:.4f}")
print(f"Participation Ratio: {iso_metrics['participation_ratio']:.1f} / {iso_metrics['nominal_dim']}")
print(f"Effective Dim (entropy): {iso_metrics['effective_dim_entropy']:.1f}")
print(f"Dims for 90% variance: {iso_metrics['dim_90_variance']}")
print()
print(diagnose_isotropy(iso_metrics))
print()

print("=" * 60)
print("ANISOTROPIC EMBEDDINGS")
print("=" * 60)
print(f"APCS: {aniso_metrics['apcs']:.4f} ± {aniso_metrics['apcs_std']:.4f}")
print(f"Isotropy Score (λ_min/λ_max): {aniso_metrics['isotropy_score']:.4f}")
print(f"Participation Ratio: {aniso_metrics['participation_ratio']:.1f} / {aniso_metrics['nominal_dim']}")
print(f"Effective Dim (entropy): {aniso_metrics['effective_dim_entropy']:.1f}")
print(f"Dims for 90% variance: {aniso_metrics['dim_90_variance']}")
print()
print(diagnose_isotropy(aniso_metrics))
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Plot eigenvalue spectra
ax1 = axes[0]
ax1.semilogy(iso_metrics['eigenvalues'][:50], 'b-', linewidth=2, label='Isotropic')
ax1.semilogy(aniso_metrics['eigenvalues'][:50], 'r-', linewidth=2, label='Anisotropic')
ax1.set_xlabel('Eigenvalue Rank')
ax1.set_ylabel('Eigenvalue (log scale)')
ax1.set_title('Eigenvalue Spectrum')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot cumulative variance explained
ax2 = axes[1]
iso_cumvar = np.cumsum(iso_metrics['eigenvalues']) / np.sum(iso_metrics['eigenvalues'])
aniso_cumvar = np.cumsum(aniso_metrics['eigenvalues']) / np.sum(aniso_metrics['eigenvalues'])
ax2.plot(iso_cumvar[:120], 'b-', linewidth=2, label='Isotropic')
ax2.plot(aniso_cumvar[:120], 'r-', linewidth=2, label='Anisotropic')
ax2.axhline(y=0.9, color='gray', linestyle='--', label='90% threshold')
ax2.set_xlabel('Number of Dimensions')
ax2.set_ylabel('Cumulative Variance Explained')
ax2.set_title('Cumulative Variance by Dimension')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Figure 217.4

Left Panel (Eigenvalue Spectrum)

  • The anisotropic embeddings (red) show a single dominant eigenvalue around 0.2, then drop sharply to ~0.005 and stay flat. This is the signature of dimensional collapse: one direction carries a disproportionate share of the variance while every other direction contributes very little.

  • The isotropic embeddings (blue) maintain relatively uniform eigenvalues around 0.01 across all ranks. No single direction dominates; variance is distributed evenly.

Right Panel (Cumulative Variance)

This panel plots the first 120 of the 128 dimensions, which reveals the key insight:

  • The anisotropic line (red) starts at 0.26: that single first dimension immediately captures over a quarter of all variance. It then climbs steadily and stays consistently above the isotropic line.

  • The isotropic line (blue) starts near zero and climbs linearly because each dimension contributes roughly equal variance (~0.8% each).

  • Both reach 90% at around the same number of dimensions (~108-113), which initially seems counterintuitive. The anisotropic embeddings get a “head start” from the dominant first dimension, then accumulate variance slowly from many weak dimensions, while the isotropic embeddings accumulate steadily throughout. The curves converge because, after that first large eigenvalue, the remaining anisotropic eigenvalues are smaller than the isotropic ones (visible in the left panel, where red falls below blue after rank 1).

The practical problem this reveals: In the anisotropic embeddings, 26% of the variance lies along a single dominant direction, “the direction everyone points” (i.e., largely non-discriminative information). Only 74% remains, spread over the other 127 dimensions, for actually distinguishing between entities.

Simulated streaming platform user embeddings
def simulate_streaming_user_embeddings(
    n_users: int = 5000,
    n_items: int = 1000,
    embedding_dim: int = 64,
    popularity_skew: float = 0.5
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Simulate user embeddings from a streaming platform.
    
    This mimics how matrix factorization or neural collaborative filtering
    learns user representations from viewing behavior.
    
    Parameters
    ----------
    n_users : int
        Number of users
    n_items : int
        Number of content items (shows, movies)
    embedding_dim : int
        Dimension of embeddings
    popularity_skew : float
        Degree of popularity bias (0 = uniform, 1+ = extreme Zipf)
        Higher values = more views concentrated on popular items
    
    Returns
    -------
    user_embeddings, item_embeddings, interaction_matrix, popularity
    """
    np.random.seed(42)
    
    # Generate item embeddings (content representation)
    item_embeddings = np.random.randn(n_items, embedding_dim)
    item_embeddings = item_embeddings / np.linalg.norm(item_embeddings, axis=1, keepdims=True)
    
    # Item popularity follows Zipf distribution
    ranks = np.arange(1, n_items + 1)
    popularity = 1 / (ranks ** popularity_skew)
    popularity = popularity / popularity.sum()
    
    # Generate user viewing history
    views_per_user = np.random.poisson(50, n_users)
    
    interaction_matrix = np.zeros((n_users, n_items))
    
    for user_idx in range(n_users):
        n_views = views_per_user[user_idx]
        
        # Each user has latent preferences
        user_preference = np.random.randn(embedding_dim)
        user_preference = user_preference / np.linalg.norm(user_preference)
        
        # Viewing probability combines preference and popularity
        preference_scores = item_embeddings @ user_preference
        combined_scores = preference_scores + 3 * np.log(popularity + 1e-10)
        probs = np.exp(combined_scores - combined_scores.max())
        probs = probs / probs.sum()
        
        viewed_items = np.random.choice(n_items, size=n_views, replace=True, p=probs)
        for item in viewed_items:
            interaction_matrix[user_idx, item] += 1
    
    # Learn user embeddings as weighted average of viewed items
    user_embeddings = np.zeros((n_users, embedding_dim))
    for user_idx in range(n_users):
        weights = interaction_matrix[user_idx]
        if weights.sum() > 0:
            user_embeddings[user_idx] = (weights @ item_embeddings) / weights.sum()
    
    norms = np.linalg.norm(user_embeddings, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1, norms)
    user_embeddings = user_embeddings / norms
    
    return user_embeddings, item_embeddings, interaction_matrix, popularity


# Compare different popularity skew levels
skew_levels = [0.0, 0.5, 1.0, 1.5]
results = {}

print("Impact of Popularity Bias on User Embedding Isotropy")
print("=" * 60)
for skew in skew_levels:
    user_emb, item_emb, interactions, pop = simulate_streaming_user_embeddings(
        n_users=5000, popularity_skew=skew
    )
    metrics = compute_isotropy_metrics(user_emb)
    results[skew] = {
        'embeddings': user_emb,
        'metrics': metrics,
        'popularity': pop
    }
    print(f"\nPopularity Skew = {skew}")
    print(f"  APCS: {metrics['apcs']:.4f}")
    print(f"  Effective Dim: {metrics['participation_ratio']:.1f}/{metrics['nominal_dim']}")
    print(f"  Top item gets {pop[0]*100:.1f}% of popularity")
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for idx, (skew, data) in enumerate(results.items()):
    ax = axes[idx // 2, idx % 2]
    
    emb = data['embeddings']
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    emb_2d = tsne.fit_transform(emb[:1000])
    
    ax.scatter(emb_2d[:, 0], emb_2d[:, 1], alpha=0.5, s=10, c='steelblue')
    ax.set_title(f"Popularity Skew = {skew}\nAPCS = {data['metrics']['apcs']:.3f}, " + 
                 f"Eff. Dim = {data['metrics']['participation_ratio']:.1f}")
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')

plt.tight_layout()
plt.show()
Figure 217.5

Figure 217.5 illustrates the representation collapse problem in recommendation systems as popularity skew increases.

  1. With no popularity skew (0.0), user embeddings are well-distributed across the space. The t-SNE shows a roughly uniform scatter. The effective dimensionality is high (59.6 of 64), meaning the model uses nearly its full representational capacity to distinguish users. APCS (average pairwise cosine similarity) is low (0.036), indicating users have diverse, distinguishable embeddings.
  2. As popularity skew increases to 0.5, users start clustering more tightly. Effective dimensionality drops to 8.7, and APCS jumps to 0.922: embeddings are becoming increasingly similar to each other.
  3. At skew = 1.0, the pattern continues with effective dimensionality collapsing to just 2.1.
  4. At extreme skew (1.5), the collapse is dramatic: the embeddings now form distinct tight clusters and a characteristic “horseshoe” or curved manifold structure. Effective dimensionality is only 1.4, and APCS hits 0.999: nearly all user embeddings have converged to essentially the same representation.

The takeaway: When training data is dominated by popular items (high popularity skew), the model learns to represent every user similarly rather than capturing user-specific tastes. This destroys the model’s ability to make personalized recommendations: if all users look the same in embedding space, the system cannot meaningfully distinguish between them. This motivates techniques like popularity debiasing, inverse propensity weighting, or contrastive objectives to maintain representational diversity.
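
To make the debiasing idea concrete, here is a minimal sketch, not the chapter's simulation. It builds a small hypothetical catalog with Zipf-distributed popularity, forms user embeddings as view-weighted sums of item vectors, and then repeats the step with each view count divided by the item's popularity. That reweighting is a simplified stand-in for inverse propensity weighting, and every name and parameter below (`apcs`, the sizes, the taste strength of 2.0) is an assumption chosen for illustration.

```python
import numpy as np

def apcs(embeddings: np.ndarray) -> float:
    """Average pairwise cosine similarity over all distinct pairs."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.maximum(norms, 1e-12)
    sims = unit @ unit.T
    return float(sims[np.triu_indices(len(unit), k=1)].mean())

rng = np.random.default_rng(0)
n_users, n_items, dim, n_views = 300, 500, 32, 50

# Hypothetical catalog: random unit item vectors, Zipf(1.0) popularity
items = rng.standard_normal((n_items, dim))
items /= np.linalg.norm(items, axis=1, keepdims=True)
popularity = 1.0 / np.arange(1, n_items + 1)
popularity /= popularity.sum()

# Each user samples views from popularity tilted by a personal taste vector
counts = np.zeros((n_users, n_items))
for u in range(n_users):
    taste = rng.standard_normal(dim)
    taste /= np.linalg.norm(taste)
    probs = popularity * np.exp(2.0 * (items @ taste))
    probs /= probs.sum()
    counts[u] = rng.multinomial(n_views, probs)

# Plain aggregation vs. inverse-popularity reweighting of the same views
plain_users = counts @ items
reweighted_users = (counts / popularity) @ items

apcs_plain = apcs(plain_users)
apcs_reweighted = apcs(reweighted_users)
print(f"APCS plain:      {apcs_plain:.3f}")
print(f"APCS reweighted: {apcs_reweighted:.3f}")
```

Dividing by popularity shrinks the contribution of the few blockbuster items that every user has watched, so the shared component of the user vectors gets smaller. The cost is that rarely viewed items receive very large weights, which is why practical systems usually clip or smooth the weights rather than applying them raw.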

217.0.12 Correcting Anisotropy

Several techniques can improve embedding isotropy:

  1. Post-hoc whitening (ZCA): Transform embeddings to have identity covariance:

\[ \tilde{\mathbf{E}} = (\mathbf{E} - \boldsymbol{\mu}) \mathbf{\Sigma}^{-1/2} \]

  2. All-but-the-top removal: Remove the top \(k\) principal components that capture the “common direction.” This technique, proposed by Mu, Bhat, and Viswanath (2017), removes the mean direction that dominates anisotropic spaces.
  3. Contrastive training objectives: Methods like SimCLR and uniformity losses encourage isotropy during training (Wang and Isola 2020; Chen et al. 2020).
Methods to correct anisotropic embeddings
def whiten_embeddings(embeddings: np.ndarray, center: bool = True) -> np.ndarray:
    """Apply ZCA whitening to make embeddings isotropic."""
    if center:
        mean = np.mean(embeddings, axis=0)
        centered = embeddings - mean
    else:
        centered = embeddings.copy()
    
    cov = (centered.T @ centered) / len(centered)
    eigenvalues, eigenvectors = np.linalg.eigh(cov)
    eigenvalues = np.maximum(eigenvalues, 1e-6)
    
    whitening_matrix = eigenvectors @ np.diag(1.0 / np.sqrt(eigenvalues)) @ eigenvectors.T
    return centered @ whitening_matrix


def remove_top_components(embeddings: np.ndarray, n_remove: int = 1) -> np.ndarray:
    """Remove top principal components (the common direction)."""
    mean = np.mean(embeddings, axis=0)
    centered = embeddings - mean
    
    U, S, Vt = svd(centered, full_matrices=False)
    top_components = Vt[:n_remove]
    
    result = centered.copy()
    for component in top_components:
        projections = (centered @ component).reshape(-1, 1)
        result = result - projections * component
    
    return result


# Demonstrate correction
print("CORRECTION METHODS FOR ANISOTROPIC EMBEDDINGS")
print("=" * 60)

print("\nOriginal anisotropic embeddings:")
print(diagnose_isotropy(aniso_metrics))

whitened = whiten_embeddings(anisotropic_emb)
whitened_metrics = compute_isotropy_metrics(whitened)
print("\nAfter ZCA whitening:")
print(diagnose_isotropy(whitened_metrics))

cleaned = remove_top_components(anisotropic_emb, n_remove=3)
cleaned_metrics = compute_isotropy_metrics(cleaned)
print("\nAfter removing top 3 components:")
print(diagnose_isotropy(cleaned_metrics))
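
The contrastive option listed above acts during training, so it cannot be applied to a fixed embedding matrix the way whitening can. The uniformity term of Wang and Isola (2020), the log of the mean Gaussian potential \(\exp(-t\|\mathbf{x}-\mathbf{y}\|^2)\) over pairs of unit-normalized embeddings, can still be computed after the fact as a diagnostic. The sketch below is an assumption-laden illustration: the function name `uniformity_loss` and the synthetic inputs are hypothetical, and \(t = 2\) follows the default in that paper.

```python
import numpy as np

def uniformity_loss(embeddings: np.ndarray, t: float = 2.0) -> float:
    """Log of the mean Gaussian potential over all distinct pairs.

    Lower (more negative) values mean the unit-normalized embeddings
    are spread more evenly over the hypersphere.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.maximum(norms, 1e-12)
    # For unit vectors, ||x - y||^2 = 2 - 2 * cos(x, y)
    sq_dists = 2.0 - 2.0 * (unit @ unit.T)
    pairs = np.triu_indices(len(unit), k=1)
    return float(np.log(np.mean(np.exp(-t * sq_dists[pairs]))))

rng = np.random.default_rng(0)
spread = rng.standard_normal((1000, 32))                        # well spread
collapsed = np.ones(32) + 0.01 * rng.standard_normal((1000, 32))  # narrow cone

u_spread = uniformity_loss(spread)
u_collapsed = uniformity_loss(collapsed)
print(f"Uniformity (spread):    {u_spread:.3f}")
print(f"Uniformity (collapsed): {u_collapsed:.3f}")
```

A fully collapsed space scores close to 0, the maximum, because every pairwise distance is near zero. During training the same quantity would be minimized alongside an alignment term; here it only serves as one more number to track next to APCS.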

217.1 Hubness

We’ve established that embeddings should spread evenly across dimensions. But there’s another geometric pathology lurking in high-dimensional spaces: hubness.

Some points become “hubs” that appear as nearest neighbors of many other points, even when they shouldn’t be semantically related. This distorts retrieval and recommendation.

Hubness emerges from the geometry of high-dimensional spaces, interacts with anisotropy, and has its own distinct consequences and remedies. We turn to this phenomenon next.

217.1.1 The Curse of Dimensionality and Hubness

In high-dimensional spaces, hubness distorts nearest-neighbor retrieval. Hubs appear as nearest neighbors of disproportionately many other points; conversely, anti-hubs rarely appear as anyone’s neighbor.

Radovanovic, Nanopoulos, and Ivanovic (2010) showed that this emerges from the concentration of distances in high dimensions. As dimensionality increases, distances become more uniform, making nearest-neighbor relationships increasingly arbitrary.
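
Distance concentration is easy to observe directly. The following sketch, which uses synthetic Gaussian data and a hypothetical helper named `relative_contrast`, measures how much farther the farthest point is than the nearest one, relative to the nearest distance, as dimensionality grows.

```python
import numpy as np

def relative_contrast(n_points: int, dim: int, seed: int = 0) -> float:
    """(d_max - d_min) / d_min for one random query against random points."""
    rng = np.random.default_rng(seed)
    points = rng.standard_normal((n_points, dim))
    query = rng.standard_normal(dim)
    dists = np.linalg.norm(points - query, axis=1)
    return float((dists.max() - dists.min()) / dists.min())

contrasts = {dim: relative_contrast(2000, dim) for dim in (2, 16, 128, 1024)}
for dim, contrast in contrasts.items():
    print(f"dim = {dim:5d}   relative contrast = {contrast:.3f}")
```

As the contrast shrinks, the nearest neighbor is barely closer than a typical point, so small geometric advantages (such as lying near the data mean) are enough to make a point a neighbor of many queries.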

Let \(N_k(x)\) denote the k-occurrence of point \(x\): how many times \(x\) appears among the \(k\) nearest neighbors of other points. Hubness manifests as positive skew in \(N_k\):

\[ S_{N_k} = \frac{\mathbb{E}[(N_k - \mu_{N_k})^3]}{\sigma_{N_k}^3} \]

Functions to measure and diagnose hubness
def compute_hubness_metrics(embeddings: np.ndarray, k: int = 10) -> Dict:
    """
    Compute hubness metrics for embedding space.
    """
    n = len(embeddings)
    
    distances = cdist(embeddings, embeddings, metric='euclidean')
    np.fill_diagonal(distances, np.inf)
    knn_indices = np.argsort(distances, axis=1)[:, :k]
    
    k_occurrences = np.zeros(n)
    for neighbors in knn_indices:
        for neighbor in neighbors:
            k_occurrences[neighbor] += 1
    
    skewness = stats.skew(k_occurrences)
    
    expected_occurrences = k
    hub_threshold = expected_occurrences + 2 * np.std(k_occurrences)
    antihub_threshold = max(0, expected_occurrences - 2 * np.std(k_occurrences))
    
    hubs = np.where(k_occurrences > hub_threshold)[0]
    antihubs = np.where(k_occurrences < antihub_threshold)[0]
    
    mean_occ = np.mean(k_occurrences)
    excess = np.sum(np.maximum(k_occurrences - mean_occ, 0))
    robin_hood = excess / np.sum(k_occurrences)
    
    return {
        'skewness': skewness,
        'n_hubs': len(hubs),
        'n_antihubs': len(antihubs),
        'hub_fraction': len(hubs) / n,
        'antihub_fraction': len(antihubs) / n,
        'robin_hood_index': robin_hood,
        'k_occurrences': k_occurrences,
        'max_occurrences': np.max(k_occurrences),
        'min_occurrences': np.min(k_occurrences),
        'mean_occurrences': np.mean(k_occurrences),
        'std_occurrences': np.std(k_occurrences)
    }


def diagnose_hubness(metrics: Dict) -> str:
    """Interpret hubness metrics."""
    diagnosis = []
    
    if metrics['skewness'] < 0.5:
        diagnosis.append("✓ Low hubness (skewness < 0.5)")
    elif metrics['skewness'] < 1.5:
        diagnosis.append("⚠ Moderate hubness (0.5 ≤ skewness < 1.5)")
    else:
        diagnosis.append("✗ Severe hubness (skewness ≥ 1.5)")
    
    diagnosis.append(f"  Hubs: {metrics['n_hubs']} ({metrics['hub_fraction']:.1%})")
    diagnosis.append(f"  Anti-hubs: {metrics['n_antihubs']} ({metrics['antihub_fraction']:.1%})")
    diagnosis.append(f"  k-occurrence range: [{metrics['min_occurrences']:.0f}, {metrics['max_occurrences']:.0f}]")
    
    return "\n".join(diagnosis)


# Compute hubness
print("HUBNESS ANALYSIS")
print("=" * 60)

print("\nIsotropic embeddings:")
iso_hubness = compute_hubness_metrics(isotropic_emb[:2000], k=10)
print(diagnose_hubness(iso_hubness))

print("\nAnisotropic embeddings:")
aniso_hubness = compute_hubness_metrics(anisotropic_emb[:2000], k=10)
print(diagnose_hubness(aniso_hubness))
fig, axes = plt.subplots(1, 2)

ax1 = axes[0]
ax1.hist(iso_hubness['k_occurrences'], bins=30, edgecolor='black', alpha=0.7, color='steelblue')
ax1.axvline(iso_hubness['mean_occurrences'], color='red', linestyle='--', linewidth=2,
            label=f"Mean = {iso_hubness['mean_occurrences']:.1f}")
ax1.set_xlabel('k-Occurrences')
ax1.set_ylabel('Frequency')
ax1.set_title(f'Isotropic Embeddings\nSkewness = {iso_hubness["skewness"]:.2f}')
ax1.legend()

ax2 = axes[1]
ax2.hist(aniso_hubness['k_occurrences'], bins=30, edgecolor='black', alpha=0.7, color='coral')
ax2.axvline(aniso_hubness['mean_occurrences'], color='red', linestyle='--', linewidth=2,
            label=f"Mean = {aniso_hubness['mean_occurrences']:.1f}")
ax2.set_xlabel('k-Occurrences')
ax2.set_ylabel('Frequency')
ax2.set_title(f'Anisotropic Embeddings\nSkewness = {aniso_hubness["skewness"]:.2f}')
ax2.legend()

plt.tight_layout()
plt.show()
Figure 217.6

Figure 217.6 illustrates the hubness problem in embedding spaces by showing the distribution of k-occurrences (how often each point appears as a k-nearest neighbor of other points).

  • Left panel (Isotropic Embeddings): Despite being “isotropic” (uniform variance across dimensions), you see a highly skewed distribution (skewness = 7.22). Most points appear as neighbors very few times (the tall bar near 0), but a small number of points appear as neighbors extremely often, up to 350 times. These are hubs: points that disproportionately dominate the nearest neighbor lists of many other points, even though they may not be semantically related. This is problematic for retrieval because the same few items keep getting recommended regardless of the query.

  • Right panel (Anisotropic Embeddings): Counterintuitively, the anisotropic embeddings show a much healthier, roughly symmetric distribution (skewness ≈ -0.16) centered around the expected mean of 10. Each point appears as a neighbor a roughly equal number of times, which is what you’d want for fair retrieval.

The paradox here: The results look swapped from what you’d typically expect, since anisotropic embeddings (where variance concentrates in few dimensions) are usually associated with hubness problems. Two readings are consistent with this figure:

  1. Isotropy alone doesn’t prevent hubness, which can emerge from high dimensionality itself (the “curse of dimensionality”).

  2. The specific structure of these “anisotropic” embeddings happens to mitigate hubness through some other property.

Practical implication: When evaluating embedding quality, you need to check the k-occurrence distribution, not just isotropy metrics. Hubness directly degrades retrieval performance by creating “popular” points that crowd out genuinely relevant neighbors.

217.1.2 Business Impact of Hubness

In a recommendation system, hubs manifest as items recommended to everyone regardless of preferences (typically already-popular items). Anti-hubs never get recommended despite potential relevance. This creates:

  • Popularity bias: Popular items dominate all recommendations
  • Long-tail invisibility: Niche products become undiscoverable
  • Revenue loss: Customers see the same items everywhere, reducing discovery
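
These effects can be quantified from the recommendation lists themselves. The sketch below is one possible way to do so, on synthetic data: the function `recommendation_concentration` and its output keys are hypothetical names, and ranking by dot product of unit vectors is an assumption about how candidates are scored.

```python
import numpy as np

def recommendation_concentration(user_emb: np.ndarray, item_emb: np.ndarray,
                                 k: int = 10, top_frac: float = 0.01) -> dict:
    """Summarize how top-k recommendations are spread across the catalog."""
    scores = user_emb @ item_emb.T                       # (n_users, n_items)
    # Indices of the k highest-scoring items per user (order within k ignored)
    top_k = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    counts = np.bincount(top_k.ravel(), minlength=item_emb.shape[0])
    n_top = max(1, int(round(top_frac * len(counts))))
    return {
        # Fraction of items that appear in at least one user's list
        "catalog_coverage": float(np.mean(counts > 0)),
        # Share of all recommendation slots taken by the most-recommended items
        "top_share": float(np.sort(counts)[::-1][:n_top].sum() / counts.sum()),
        # Items that no user is ever shown (anti-hubs in the extreme)
        "never_recommended": int(np.sum(counts == 0)),
    }

rng = np.random.default_rng(0)
users = rng.standard_normal((2000, 64))
items = rng.standard_normal((1000, 64))
users /= np.linalg.norm(users, axis=1, keepdims=True)
items /= np.linalg.norm(items, axis=1, keepdims=True)

report = recommendation_concentration(users, items, k=10)
print(report)
```

Low catalog coverage corresponds to long-tail invisibility, and a `top_share` far above `top_frac` indicates that a handful of hub items dominate what users see.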

217.2 Embedding Stability

Embeddings should be robust to initialization randomness, data perturbations, and temporal evolution. Unstable embeddings lead to inconsistent downstream behavior.

217.2.1 Measuring Stability Across Random Seeds

Embedding algorithms contain stochastic components (e.g., random initialization, stochastic gradient descent, negative sampling) that produce different solutions across runs. If embeddings change substantially with different random seeds, downstream applications face several problems:

  • irreproducible research findings

  • inconsistent recommendation quality in production

  • difficulty diagnosing whether performance changes stem from model improvements or random variation.

Stability is particularly important in high-stakes applications. A recommendation system that produces substantially different rankings depending on when it was trained undermines user trust and complicates A/B testing. For academic research, unstable embeddings make it difficult to attribute performance differences to methodological improvements versus lucky seeds.

217.2.1.1 The Identification Problem

Embeddings present a fundamental challenge for stability measurement: they are only identified up to orthogonal transformations. If \(\mathbf{E}\) is a valid embedding matrix, then \(\mathbf{E}\mathbf{Q}\) is equally valid for any orthogonal matrix \(\mathbf{Q}\) (rotation or reflection), since inner products remain unchanged:

\[(\mathbf{E}\mathbf{Q})(\mathbf{E}\mathbf{Q})^\top = \mathbf{E}\mathbf{Q}\mathbf{Q}^\top\mathbf{E}^\top = \mathbf{E}\mathbf{E}^\top\]

This means two embedding matrices could encode identical information while appearing completely different element-wise. Naively computing correlations between raw embedding values would severely underestimate stability.
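
A small numerical check makes the point. In this sketch (synthetic data, all variable names hypothetical), a random orthogonal matrix is applied to an embedding matrix; the Gram matrices agree to numerical precision while the raw coordinates are nearly uncorrelated.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 64
E = rng.standard_normal((n, d))

# Random orthogonal matrix from a QR decomposition; the sign correction
# makes the draw uniform over orthogonal matrices
Q, R_factor = np.linalg.qr(rng.standard_normal((d, d)))
Q = Q * np.sign(np.diag(R_factor))

E_rot = E @ Q

# Inner products (and therefore cosine similarities) are unchanged
gram_equal = np.allclose(E @ E.T, E_rot @ E_rot.T)

# Element-wise comparison of the raw coordinates suggests no relationship
raw_corr = np.corrcoef(E.ravel(), E_rot.ravel())[0, 1]

print(f"Gram matrices identical: {gram_equal}")
print(f"Correlation of raw coordinates: {raw_corr:.3f}")
```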

217.2.1.2 Procrustes Analysis

Procrustes analysis solves this identification problem by finding the optimal orthogonal transformation that aligns one embedding matrix to another before measuring differences (Gower and Dijksterhuis 2004). Given two centered and scaled embedding matrices \(\mathbf{E}_1\) and \(\mathbf{E}_2\), we seek the orthogonal matrix \(\mathbf{R}^*\) that minimizes:

\[\mathbf{R}^* = \arg\min_{\mathbf{R}^\top\mathbf{R} = \mathbf{I}} \|\mathbf{E}_1 - \mathbf{E}_2\mathbf{R}\|_F^2\]

The solution follows from the singular value decomposition. Computing \(\mathbf{M} = \mathbf{E}_1^\top\mathbf{E}_2\) and its SVD \(\mathbf{M} = \mathbf{U}\mathbf{S}\mathbf{V}^\top\), the optimal rotation is:

\[\mathbf{R}^* = \mathbf{V}\mathbf{U}^\top\]

This classic result from Schönemann (1966) provides a closed-form solution that aligns the embeddings optimally in the least-squares sense.
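
The closed form can be verified on a case where the answer is known. In this sketch (synthetic data, hypothetical variable names), \(\mathbf{E}_2\) is built by rotating \(\mathbf{E}_1\) with a known orthogonal matrix, so the optimal \(\mathbf{R}^*\) should undo that rotation exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 16
E1 = rng.standard_normal((n, d))

# Known orthogonal transformation applied to E1
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
E2 = E1 @ Q

# Closed-form solution: M = E1^T E2 = U S V^T, then R* = V U^T
M = E1.T @ E2
U, S, Vt = np.linalg.svd(M)
R_star = Vt.T @ U.T

residual = np.linalg.norm(E1 - E2 @ R_star, "fro")
print(f"Residual after alignment: {residual:.2e}")
print(f"R* equals Q^T: {np.allclose(R_star, Q.T)}")
```

SciPy provides the same solver as `scipy.linalg.orthogonal_procrustes(A, B)`, which returns the orthogonal matrix that maps `A` onto `B`; in the notation used here that would be called with `A = E2` and `B = E1`.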

217.2.1.3 Stability Metrics

After alignment, we compute three complementary metrics:

  1. Procrustes Distance measures the residual misalignment after optimal transformation:

\[d_P = \|\mathbf{E}_1 - \mathbf{E}_2\mathbf{R}^*\|_F\]

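Values near zero indicate that embeddings are nearly identical up to rotation. When every row has unit norm, the maximum possible distance is \(\sqrt{2n}\) (reached when \(\mathbf{E}_1^\top\mathbf{E}_2 = \mathbf{0}\)); when each matrix is instead scaled to unit Frobenius norm, as in the code below, the bound is \(\sqrt{2}\). Values can be interpreted relative to the applicable bound.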

  2. Mean Cosine Similarity After Alignment captures how well individual embedding vectors match their counterparts:

\[\bar{c} = \frac{1}{n}\sum_{i=1}^{n} \frac{\mathbf{e}_{1i}^\top (\mathbf{R}^{*\top}\mathbf{e}_{2i})}{\|\mathbf{e}_{1i}\| \|\mathbf{e}_{2i}\|}\]

This metric is more interpretable than Procrustes distance. Values near 1.0 indicate individual items land in nearly the same location across runs.

  3. Pairwise Similarity Correlation asks whether the similarity structure is preserved, regardless of absolute positions:

\[\rho = \text{corr}\left(\text{vec}(\mathbf{E}_1\mathbf{E}_1^\top), \text{vec}(\mathbf{E}_2\mathbf{E}_2^\top)\right)\]

This is arguably the most important metric for downstream applications. Even if individual embeddings shift, what matters is whether similar items remain similar and dissimilar items remain dissimilar. High correlation indicates the relational structure is stable.

Functions to measure embedding stability
def procrustes_similarity(emb1: np.ndarray, emb2: np.ndarray) -> Dict:
    """Compute similarity between embeddings after optimal alignment."""
    assert emb1.shape == emb2.shape
    n, d = emb1.shape
    
    emb1_centered = emb1 - np.mean(emb1, axis=0)
    emb2_centered = emb2 - np.mean(emb2, axis=0)
    
    norm1 = np.linalg.norm(emb1_centered, 'fro')
    norm2 = np.linalg.norm(emb2_centered, 'fro')
    emb1_scaled = emb1_centered / (norm1 + 1e-10)
    emb2_scaled = emb2_centered / (norm2 + 1e-10)
    
    M = emb1_scaled.T @ emb2_scaled
    U, S, Vt = svd(M)
    R = Vt.T @ U.T
    
    emb2_aligned = emb2_scaled @ R
    procrustes_distance = np.sqrt(np.sum((emb1_scaled - emb2_aligned) ** 2))
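    # Divide by row norms so this is a true cosine: after Frobenius scaling
    # the individual rows are not unit-length
    row_norms = np.linalg.norm(emb1_scaled, axis=1) * np.linalg.norm(emb2_aligned, axis=1)
    per_point_cos = np.sum(emb1_scaled * emb2_aligned, axis=1) / np.maximum(row_norms, 1e-12)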
    
    n_sample = min(1000, n)
    idx = np.random.choice(n, n_sample, replace=False)
    sim1 = emb1_scaled[idx] @ emb1_scaled[idx].T
    sim2 = emb2_aligned[idx] @ emb2_aligned[idx].T
    upper_tri = np.triu_indices(n_sample, k=1)
    correlation = np.corrcoef(sim1[upper_tri], sim2[upper_tri])[0, 1]
    
    return {
        'procrustes_distance': procrustes_distance,
        'mean_cosine_after_alignment': np.mean(per_point_cos),
        'pairwise_similarity_correlation': correlation,
        'aligned_embeddings': emb2_aligned * norm2
    }


def measure_seed_stability(embedding_fn: Callable, n_seeds: int = 5, **kwargs) -> Dict:
    """Measure stability across random seeds."""
    embeddings_list = [embedding_fn(seed=seed, **kwargs) for seed in range(n_seeds)]
    
    procrustes_distances = []
    correlations = []
    
    for i in range(n_seeds):
        for j in range(i + 1, n_seeds):
            result = procrustes_similarity(embeddings_list[i], embeddings_list[j])
            procrustes_distances.append(result['procrustes_distance'])
            correlations.append(result['pairwise_similarity_correlation'])
    
    return {
        'mean_procrustes_distance': np.mean(procrustes_distances),
        'std_procrustes_distance': np.std(procrustes_distances),
        'mean_correlation': np.mean(correlations),
        'std_correlation': np.std(correlations)
    }


# Demonstration
def generate_embeddings_with_noise(seed: int, n: int = 1000, d: int = 64, 
                                    noise_level: float = 0.1) -> np.ndarray:
    np.random.seed(42)
    base_embeddings = np.random.randn(n, d)
    np.random.seed(seed)
    noise = np.random.randn(n, d) * noise_level
    return base_embeddings + noise

print("SEED STABILITY ANALYSIS")
print("=" * 60)

high_stability = measure_seed_stability(generate_embeddings_with_noise, n_seeds=5, noise_level=0.1)
print(f"\nHigh Stability (noise=0.1):")
print(f"  Procrustes distance: {high_stability['mean_procrustes_distance']:.4f}")
print(f"  Pairwise correlation: {high_stability['mean_correlation']:.4f}")

low_stability = measure_seed_stability(generate_embeddings_with_noise, n_seeds=5, noise_level=1.0)
print(f"\nLow Stability (noise=1.0):")
print(f"  Procrustes distance: {low_stability['mean_procrustes_distance']:.4f}")
print(f"  Pairwise correlation: {low_stability['mean_correlation']:.4f}")

217.2.1.4 Interpreting Results

The demonstration compares high-stability (noise = 0.1) and low-stability (noise = 1.0) scenarios:

| Condition | Procrustes Distance | Pairwise Correlation |
|---|---|---|
| High Stability | ~0.14 | ~0.99 |
| Low Stability | ~0.89 | ~0.50 |

With low noise, embeddings are nearly identical after alignment. Procrustes distance is small and pairwise correlations approach 1.0. The similarity structure is almost perfectly preserved across seeds.

With high noise, Procrustes distance increases substantially and pairwise correlation drops to around 0.5. A correlation of 0.5 corresponds to \(r^2 \approx 0.25\), so only about a quarter of the variance in one run's pairwise similarities is explained by another run's; the remainder depends on the random seed rather than on true semantic relationships. This is a concerning level of instability for production systems.

217.2.2 Practical Guidelines

The following thresholds are rules of thumb rather than hard limits, and should be recalibrated against the downstream task:

  • Pairwise correlation > 0.95: Excellent stability; seed choice has minimal impact
  • Pairwise correlation 0.85 to 0.95: Acceptable for most applications; consider averaging across seeds
  • Pairwise correlation < 0.85: Problematic; investigate sources of instability

When stability is low, several remediation strategies exist:

  • using deterministic algorithms where available,

  • averaging embeddings across multiple seeds,

  • increasing training data or epochs to reduce optimization variance,

  • using consensus-based approaches that identify the stable core of the embedding space.
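Averaging across seeds is only meaningful once the runs share a common frame, because each run is identified only up to rotation. The sketch below is a minimal illustration under stated assumptions: every run is an (n, d) array over the same entities in the same order, the helper name `average_aligned_embeddings` is hypothetical, and alignment uses SciPy's `orthogonal_procrustes` rather than the chapter's `procrustes_similarity`.

```python
import numpy as np
from scipy.linalg import orthogonal_procrustes


def average_aligned_embeddings(runs):
    """Rotate every run onto the first one, then average (hypothetical helper)."""
    reference = runs[0]
    aligned = [reference]
    for emb in runs[1:]:
        # R minimizes ||emb @ R - reference||_F over orthogonal matrices
        R, _ = orthogonal_procrustes(emb, reference)
        aligned.append(emb @ R)
    return np.mean(aligned, axis=0)


def gram_corr(a, b):
    """Correlation between the pairwise dot-product matrices of a and b."""
    return np.corrcoef((a @ a.T).ravel(), (b @ b.T).ravel())[0, 1]


# Toy data: three "seeds" that differ by a random rotation plus small noise
rng = np.random.default_rng(0)
base = rng.standard_normal((200, 16))
runs = []
for _ in range(3):
    Q, _ = np.linalg.qr(rng.standard_normal((16, 16)))  # random orthogonal matrix
    runs.append(base @ Q + 0.01 * rng.standard_normal((200, 16)))

naive = np.mean(runs, axis=0)                 # averages incompatible coordinates
consensus = average_aligned_embeddings(runs)  # averages after alignment

print(f"naive average:   {gram_corr(naive, runs[0]):.3f}")
print(f"aligned average: {gram_corr(consensus, runs[0]):.3f}")
```

On this toy data the naive average should distort the similarity structure, while the aligned average preserves it. Real training runs differ by more than a rotation, so the gain there will be smaller.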


217.2.3 Temporal Stability Monitoring

Production embedding systems face a challenge that static evaluation cannot capture: the world changes. User preferences evolve, item catalogs turn over, and the underlying data distribution shifts. Embeddings trained on historical data gradually become stale, degrading recommendation quality in ways that may not trigger obvious failures. Temporal stability monitoring provides early warning of drift before it impacts business metrics.

This problem is distinct from random seed instability. Seed instability reflects optimization variance under fixed conditions; temporal drift reflects genuine changes in the underlying relationships the embeddings encode. Both matter, but they require different monitoring approaches and remediation strategies.

217.2.3.1 Sources of Embedding Drift

Embedding drift emerges from several mechanisms:

  • Concept drift occurs when the relationships between entities genuinely change. A product that was premium becomes mainstream; a creator who made comedy pivots to drama; political terminology shifts valence. The old embeddings accurately reflected past relationships but no longer describe the current reality.

  • Population drift arises from changes in the entity set itself. New items lack embedding history and must be inferred or cold-started. Departed items leave gaps in the similarity structure. If popular items churn frequently, large portions of the embedding space become unstable.

  • Feedback loops create self-reinforcing drift. Recommendations based on current embeddings shape user behavior, which generates training data for future embeddings. Small initial biases can amplify over time, causing embeddings to drift toward degenerate states that reflect algorithmic artifacts rather than user preferences.

  • Distribution shift in training data (e.g., seasonal patterns, marketing campaigns, external events) can cause embeddings to fluctuate even when underlying preferences remain stable. Monitoring must distinguish meaningful drift from noise.

217.2.3.2 Drift Detection Framework

Effective monitoring requires comparing embeddings across time points while accounting for the identification problem discussed earlier. Our framework computes several complementary metrics at each monitoring interval:

  1. Mean Embedding Drift measures average movement in the aligned embedding space:

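\[\bar{d}_t = \frac{1}{n}\sum_{i=1}^{n} \|\mathbf{e}_i^{(t)} - {\mathbf{R}^*}^\top\mathbf{e}_i^{(t-1)}\|_2\]

where \(\mathbf{R}^*\) is the orthogonal Procrustes alignment matrix that rotates the earlier snapshot onto the current one. It is written here for column vectors \(\mathbf{e}_i\) (the rows of \(\mathbf{E}\)), so that it matches the matrix form \(\mathbf{E}^{(t-1)}\mathbf{R}^*\) used for Procrustes distance. This captures typical displacement magnitude. The standard deviation of per-entity drift identifies whether movement is uniform or concentrated in specific items. The TemporalEmbeddingMonitor class below reports the cosine-distance analogue of this quantity, which is insensitive to embedding norm.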

  2. Pairwise Correlation tracks structural consistency:

\[\rho_t = \text{corr}\left(\text{vec}(\mathbf{E}^{(t)}{\mathbf{E}^{(t)}}^\top), \text{vec}(\mathbf{E}^{(t-1)}{\mathbf{E}^{(t-1)}}^\top)\right)\]

This metric is invariant to global rotations and uniform rescaling of the embedding space, so it focuses on whether the similarity structure (i.e., which items are similar to which) remains stable. Correlation above 0.95 typically indicates acceptable stability; drops below 0.90 warrant investigation.

  3. Anomalous Fraction identifies entity-level outliers:

\[f_{\text{anom}} = \frac{1}{n}\sum_{i=1}^{n} \mathbf{1}\left[d_i > \mu_d + k\sigma_d\right]\]

where \(d_i\) is the drift for entity \(i\) and \(k\) is typically set to 2 or 3. A baseline anomaly rate around 2-5% is expected from natural variation. Elevated rates suggest systematic issues: perhaps a category of items was relabeled, or a data pipeline error corrupted certain features.

  4. Procrustes Distance provides a global summary:

\[D_P^{(t)} = \|\mathbf{E}^{(t)} - \mathbf{E}^{(t-1)}\mathbf{R}^*\|_F\]

Unlike mean drift, this is a sum rather than an average, so it grows with the number of entities as well as with embedding dimension and scale. It’s most useful for tracking trends over time rather than interpreting absolute values.
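A literal translation of the four formulas can help when checking a monitoring implementation. The sketch below makes several assumptions: both snapshots are row-wise (n, d) matrices over the same entities, the earlier snapshot is aligned onto the later one with SciPy's `orthogonal_procrustes`, and k = 2. The function name `drift_metrics` is illustrative and is not the chapter's `procrustes_similarity`.

```python
import numpy as np
from scipy.linalg import orthogonal_procrustes


def drift_metrics(E_prev, E_curr, k=2.0):
    """Compute the four drift metrics for two (n, d) snapshots (illustrative)."""
    # R* = argmin over orthogonal R of ||E_prev @ R - E_curr||_F
    R, _ = orthogonal_procrustes(E_prev, E_curr)
    residual = E_curr - E_prev @ R

    d = np.linalg.norm(residual, axis=1)       # per-entity drift d_i
    gram_prev = (E_prev @ E_prev.T).ravel()    # vec(E E^T) at t-1
    gram_curr = (E_curr @ E_curr.T).ravel()    # vec(E E^T) at t

    return {
        "mean_drift": d.mean(),
        "pairwise_correlation": np.corrcoef(gram_curr, gram_prev)[0, 1],
        "anomalous_fraction": np.mean(d > d.mean() + k * d.std()),
        "procrustes_distance": np.linalg.norm(residual),  # Frobenius norm
    }


# Synthetic snapshots: an arbitrary rotation plus small per-entity noise
rng = np.random.default_rng(0)
E_prev = rng.standard_normal((500, 32))
Q, _ = np.linalg.qr(rng.standard_normal((32, 32)))
E_curr = E_prev @ Q + 0.05 * rng.standard_normal((500, 32))

m = drift_metrics(E_prev, E_curr)
for name, value in m.items():
    print(f"{name}: {value:.4f}")
```

The raw coordinates of the two snapshots differ entirely, yet the pairwise correlation stays near 1 because the Gram matrix is unchanged by rotation, and the alignment step absorbs the rotation before drift is measured.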

Monitoring temporal stability
class TemporalEmbeddingMonitor:
    """Monitor embedding stability over time."""
    
    def __init__(self, entity_ids: List[str]):
        self.entity_ids = entity_ids
        self.n_entities = len(entity_ids)
        self.history = []
        self.stability_metrics = []
    
    def record_snapshot(self, timestamp, embeddings: np.ndarray):
        self.history.append({'timestamp': timestamp, 'embeddings': embeddings.copy()})
        
        if len(self.history) >= 2:
            prev = self.history[-2]['embeddings']
            curr = self.history[-1]['embeddings']
            metrics = self._compute_transition_metrics(prev, curr)
            metrics['timestamp'] = timestamp
            self.stability_metrics.append(metrics)
    
    def _compute_transition_metrics(self, emb_prev: np.ndarray, emb_curr: np.ndarray) -> Dict:
        alignment = procrustes_similarity(emb_prev, emb_curr)
        emb_curr_aligned = alignment['aligned_embeddings']
        
        prev_norm = emb_prev / (np.linalg.norm(emb_prev, axis=1, keepdims=True) + 1e-10)
        curr_norm = emb_curr_aligned / (np.linalg.norm(emb_curr_aligned, axis=1, keepdims=True) + 1e-10)
        
        per_entity_drift = 1 - np.sum(prev_norm * curr_norm, axis=1)
        drift_threshold = np.mean(per_entity_drift) + 2 * np.std(per_entity_drift)
        anomalous_entities = np.where(per_entity_drift > drift_threshold)[0]
        
        return {
            'procrustes_distance': alignment['procrustes_distance'],
            'mean_drift': np.mean(per_entity_drift),
            'max_drift': np.max(per_entity_drift),
            'std_drift': np.std(per_entity_drift),
            'pairwise_correlation': alignment['pairwise_similarity_correlation'],
            'n_anomalous': len(anomalous_entities),
            'anomalous_fraction': len(anomalous_entities) / self.n_entities,
            'per_entity_drift': per_entity_drift
        }
    
    def get_stability_report(self) -> pd.DataFrame:
        if not self.stability_metrics:
            return pd.DataFrame()
        df = pd.DataFrame(self.stability_metrics)
        return df.drop(columns=['per_entity_drift'], errors='ignore')
    
    def detect_drift_events(self, window_size: int = 5, threshold_sigma: float = 2.0) -> List[Dict]:
        if len(self.stability_metrics) < window_size + 1:
            return []
        
        drift_values = [m['mean_drift'] for m in self.stability_metrics]
        alerts = []
        
        for i in range(window_size, len(drift_values)):
            baseline = drift_values[i-window_size:i]
            baseline_mean = np.mean(baseline)
            baseline_std = np.std(baseline) + 1e-10
            
            current = drift_values[i]
            z_score = (current - baseline_mean) / baseline_std
            
            if z_score > threshold_sigma:
                alerts.append({
                    'timestamp': self.stability_metrics[i]['timestamp'],
                    'drift': current,
                    'z_score': z_score,
                    'n_anomalous': self.stability_metrics[i]['n_anomalous']
                })
        
        return alerts


# Demonstration
np.random.seed(42)
n_users = 1000
embedding_dim = 64
n_periods = 20

user_ids = [f"user_{i}" for i in range(n_users)]
monitor = TemporalEmbeddingMonitor(user_ids)

base_embeddings = np.random.randn(n_users, embedding_dim)

print("TEMPORAL STABILITY MONITORING SIMULATION")
print("=" * 60)

for period in range(n_periods):
    noise = np.random.randn(n_users, embedding_dim) * 0.05
    
    if period == 10:
        noise *= 5
        print(f"⚠ Period {period}: Injected GLOBAL drift event")
    
    if period == 15:
        noise[100:150] *= 10
        print(f"⚠ Period {period}: Injected ENTITY-SPECIFIC drift")
    
    current_embeddings = base_embeddings + noise
    base_embeddings = current_embeddings
    monitor.record_snapshot(timestamp=period, embeddings=current_embeddings)

stability_df = monitor.get_stability_report()
print("\nStability Report (sample):")
print(stability_df[['timestamp', 'mean_drift', 'anomalous_fraction', 'pairwise_correlation']].head(10).to_string())

drift_alerts = monitor.detect_drift_events(window_size=5)
print("\nDetected Drift Events:")
for alert in drift_alerts:
    print(f"  Period {alert['timestamp']}: z-score={alert['z_score']:.2f}")
Plot temporal stability dashboard
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

ax1 = axes[0, 0]
ax1.plot(stability_df['timestamp'], stability_df['mean_drift'], 'b-o', linewidth=2)
ax1.fill_between(stability_df['timestamp'],
                  stability_df['mean_drift'] - stability_df['std_drift'],
                  stability_df['mean_drift'] + stability_df['std_drift'], alpha=0.3)
ax1.set_xlabel('Time Period')
ax1.set_ylabel('Mean Drift')
ax1.set_title('Mean Embedding Drift Over Time')
for alert in drift_alerts:
    ax1.axvline(x=alert['timestamp'], color='red', linestyle='--', alpha=0.7)

ax2 = axes[0, 1]
ax2.plot(stability_df['timestamp'], stability_df['pairwise_correlation'], 'g-o', linewidth=2)
ax2.set_xlabel('Time Period')
ax2.set_ylabel('Pairwise Correlation')
ax2.set_title('Structural Consistency Over Time')
ax2.set_ylim([0.5, 1.05])

ax3 = axes[1, 0]
ax3.bar(stability_df['timestamp'], stability_df['anomalous_fraction'], color='orange', alpha=0.7)
ax3.set_xlabel('Time Period')
ax3.set_ylabel('Anomalous Fraction')
ax3.set_title('Entity-Level Anomalies')
ax3.axhline(y=0.05, color='red', linestyle='--', label='5% threshold')
ax3.legend()

ax4 = axes[1, 1]
ax4.plot(stability_df['timestamp'], stability_df['procrustes_distance'], 'm-o', linewidth=2)
ax4.set_xlabel('Time Period')
ax4.set_ylabel('Procrustes Distance')
ax4.set_title('Global Structure Shift')

plt.tight_layout()
plt.show()
Figure 217.7

217.2.3.3 Interpreting the Monitoring Dashboard

Figure 217.7 presents a four-panel monitoring dashboard typical of production embedding systems:

Panel A (Mean Drift Over Time) shows the average embedding displacement at each monitoring interval, with a shaded band of ±1 standard deviation of per-entity drift. Gradual upward trends suggest accumulating drift that may require retraining. Sudden spikes, marked by red vertical lines, indicate drift alerts (i.e., periods where movement exceeded normal thresholds). These warrant immediate investigation: What changed in the data? Was there a pipeline issue? Did a major external event shift user behavior?

Panel B (Structural Consistency) tracks pairwise correlation over time. This is often the most actionable metric because it directly reflects whether the similarity relationships that drive recommendations remain valid. Stable correlation near 1.0 indicates the embedding structure is holding despite surface-level drift. Declining correlation signals that the fundamental organization of the embedding space is changing (i.e., similar items are becoming dissimilar, or vice versa).

Panel C (Entity-Level Anomalies) shows what fraction of entities experienced unusually large drift at each time point. The red dashed line indicates a 5% threshold; rates consistently above this suggest systematic issues rather than random variation. Examining which specific entities are flagged often reveals the root cause: for example, all items from a particular category, or all items added after a certain date.

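Panel D (Global Structure Shift) displays Procrustes distance over time. This aggregate measure is useful for detecting regime changes (i.e., periods where the overall embedding geometry shifted substantially). Because alignment removes any global rotation, the residual reflects movement that no rotation can explain. In the unnormalized form given above it also responds to global rescaling, which pairwise correlation ignores, so it can surface issues such as feature normalization bugs that preserve relative similarities while changing embedding magnitudes. If the implementation rescales both snapshots before alignment, scale changes are removed as well and need a separate check on embedding norms.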
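Tracing flagged entities back to metadata can be as simple as a group-by. The sketch below assumes a per-entity drift vector such as the `per_entity_drift` array returned by the monitor, plus a hypothetical `category` label for each entity. All data here is synthetic.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 1000

# Synthetic stand-ins: drift values and a hypothetical category per entity
category = rng.choice(["drama", "comedy", "sports", "news"], size=n)
drift = rng.gamma(shape=2.0, scale=0.01, size=n)
drift[category == "news"] *= 4.0        # simulate one category drifting more

# Same mean + 2 sigma rule the monitor uses to flag anomalous entities
threshold = drift.mean() + 2 * drift.std()
flagged = drift > threshold

summary = (
    pd.DataFrame({"category": category, "flagged": flagged, "drift": drift})
    .groupby("category")
    .agg(
        n_entities=("flagged", "size"),
        anomaly_rate=("flagged", "mean"),
        mean_drift=("drift", "mean"),
    )
    .sort_values("anomaly_rate", ascending=False)
)
print(summary)
```

A category whose anomaly rate sits far above the others is a natural first place to look for relabeling, catalog changes, or a broken feature pipeline.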

217.2.3.4 Alert Thresholds and Response

Setting appropriate alert thresholds requires balancing sensitivity against alert fatigue. Overly sensitive thresholds generate false alarms that teams learn to ignore; overly permissive thresholds miss genuine drift until downstream metrics suffer. Table 217.3 shows a reasonable starting point.

Table 217.3: Starting Point for Temporal Stability Monitoring
| Metric | Warning | Critical |
|---|---|---|
| Mean Drift | > 1.5× baseline | > 2.5× baseline |
| Pairwise Correlation | < 0.95 | < 0.90 |
| Anomalous Fraction | > 5% | > 10% |
| Procrustes Distance | > 2σ above trend | > 3σ above trend |

These thresholds should be calibrated to each system based on historical variability and the cost of false positives versus missed detections.
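The starting thresholds in Table 217.3 can be encoded as a small rule set. This is a sketch under stated assumptions: `baseline_drift` is the mean drift measured during a stable period, `trend_mean` and `trend_std` summarize recent Procrustes distances, and the function name `classify_alerts` is hypothetical.

```python
def classify_alerts(metrics, baseline_drift, trend_mean, trend_std):
    """Map one period's metrics to 'ok' / 'warning' / 'critical' per Table 217.3."""

    def level(value, warn, crit, higher_is_worse=True):
        # For metrics where lower is worse, flip signs so one comparison suffices
        if not higher_is_worse:
            value, warn, crit = -value, -warn, -crit
        if value > crit:
            return "critical"
        if value > warn:
            return "warning"
        return "ok"

    return {
        "mean_drift": level(
            metrics["mean_drift"], 1.5 * baseline_drift, 2.5 * baseline_drift
        ),
        "pairwise_correlation": level(
            metrics["pairwise_correlation"], 0.95, 0.90, higher_is_worse=False
        ),
        "anomalous_fraction": level(metrics["anomalous_fraction"], 0.05, 0.10),
        "procrustes_distance": level(
            metrics["procrustes_distance"],
            trend_mean + 2 * trend_std,
            trend_mean + 3 * trend_std,
        ),
    }


# Made-up values for a single monitoring period
example = {
    "mean_drift": 0.004,
    "pairwise_correlation": 0.93,
    "anomalous_fraction": 0.03,
    "procrustes_distance": 0.50,
}
levels = classify_alerts(example, baseline_drift=0.002, trend_mean=0.30, trend_std=0.05)
print(levels)
```

Keeping the thresholds as arguments or configuration rather than hard-coded constants makes per-system calibration easier.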

When alerts trigger, the response depends on severity and pattern:

  • Isolated spikes often reflect data quality issues, such as missing features, pipeline delays, or upstream changes. Investigate recent deployments and data source modifications.
  • Gradual trends indicate natural drift requiring scheduled retraining. The monitoring data helps determine optimal retraining frequency.
  • Sudden regime shifts suggest major changes, such as new item categories, user population shifts, or algorithm modifications. These may require not just retraining but model architecture review.

217.2.3.5 Practical Considerations

  • Baseline establishment is critical. Before setting thresholds, collect several periods of monitoring data under stable conditions to characterize normal variation. Systems exhibit natural fluctuation from batch composition differences, time-of-day effects, and random sampling.

  • Alignment consistency requires using the same reference point for Procrustes alignment across monitoring periods. Typically this means aligning all snapshots to an initial baseline embedding rather than chaining alignments (which can accumulate errors).

  • Computational efficiency becomes important at scale. Computing full pairwise similarity matrices is \(O(n^2)\); sampling strategies or locality-sensitive hashing can reduce this for large item catalogs while maintaining statistical validity.

  • Causal attribution remains challenging. Drift detection tells you something changed but not why. Integrating embedding monitoring with data quality metrics, deployment logs, and external event calendars helps narrow down root causes.
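One way to avoid the quadratic cost noted above is to estimate the pairwise correlation from a random sample of entity pairs. This is a minimal sketch, assuming dot-product similarity and uniformly sampled off-diagonal pairs; `sampled_pairwise_correlation` and its parameters are illustrative names.

```python
import numpy as np


def sampled_pairwise_correlation(E_prev, E_curr, n_pairs=50_000, seed=0):
    """Estimate the correlation of pairwise dot products from random (i, j) pairs."""
    rng = np.random.default_rng(seed)
    n = E_prev.shape[0]
    i = rng.integers(0, n, size=n_pairs)
    j = rng.integers(0, n, size=n_pairs)
    keep = i != j                      # drop self-pairs
    i, j = i[keep], j[keep]

    # Row-wise dot products: O(n_pairs * d) instead of O(n^2 * d)
    sim_prev = np.einsum("ij,ij->i", E_prev[i], E_prev[j])
    sim_curr = np.einsum("ij,ij->i", E_curr[i], E_curr[j])
    return np.corrcoef(sim_prev, sim_curr)[0, 1]


rng = np.random.default_rng(1)
E_prev = rng.standard_normal((2000, 32))
E_curr = E_prev + 0.3 * rng.standard_normal((2000, 32))

# Exact value over all off-diagonal pairs, feasible only because n is small here
mask = ~np.eye(2000, dtype=bool)
exact = np.corrcoef((E_prev @ E_prev.T)[mask], (E_curr @ E_curr.T)[mask])[0, 1]
approx = sampled_pairwise_correlation(E_prev, E_curr)
print(f"exact={exact:.4f}  sampled={approx:.4f}")
```

This estimate excludes the diagonal of the similarity matrix, so it can differ slightly from a correlation computed over the full matrix. For catalogs with highly skewed popularity, stratified sampling may be a better fit than uniform pairs.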


218 Extrinsic Evaluation: Downstream Tasks

While intrinsic metrics assess embedding quality in isolation, extrinsic evaluation measures performance on actual tasks.

218.3 Node Classification

Test whether embeddings capture label-relevant structure by training a simple classifier on the learned representations. If a logistic regression on embeddings achieves high accuracy, the embedding space encodes the information needed to distinguish node categories.

Node classification evaluation
def evaluate_node_classification(
    embeddings: np.ndarray,
    labels: np.ndarray,
    train_fractions: List[float] = [0.1, 0.3, 0.5, 0.7, 0.9],
    n_trials: int = 10,
    random_state: int = 42
) -> pd.DataFrame:
    """
    Evaluate embeddings via node classification.

    Trains logistic regression on embeddings at varying label budgets
    and reports accuracy, micro-F1, and macro-F1 (mean and standard
    deviation) across random train/test splits.

    Parameters
    ----------
    embeddings : np.ndarray
        Node embedding matrix of shape (n_nodes, d).
    labels : np.ndarray
        Integer node labels of shape (n_nodes,).
    train_fractions : list of float
        Fractions of labeled data to use for training.
    n_trials : int
        Number of random splits per fraction.
    random_state : int
        Base random seed for reproducibility.

    Returns
    -------
    pd.DataFrame
        Classification performance at each training fraction.
    """
    rng = np.random.RandomState(random_state)
    results = []

    for train_frac in train_fractions:
        trial_scores = []

        for trial in range(n_trials):
            n = len(embeddings)
            indices = rng.permutation(n)
            n_train = int(n * train_frac)

            train_idx = indices[:n_train]
            test_idx = indices[n_train:]

            if len(test_idx) == 0:
                continue

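            # No multi_class argument: it is deprecated in recent
            # scikit-learn releases, and lbfgs already fits a multinomial
            # model for multiclass targets.
            clf = LogisticRegression(
                max_iter=1000,
                solver="lbfgs",
                random_state=trial
            )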
            clf.fit(embeddings[train_idx], labels[train_idx])

            y_pred = clf.predict(embeddings[test_idx])
            y_true = labels[test_idx]

            trial_scores.append({
                "accuracy": accuracy_score(y_true, y_pred),
                "f1_micro": f1_score(y_true, y_pred, average="micro"),
                "f1_macro": f1_score(y_true, y_pred, average="macro"),
            })

        scores_df = pd.DataFrame(trial_scores)
        results.append({
            "train_fraction": train_frac,
            "accuracy_mean": scores_df["accuracy"].mean(),
            "accuracy_std": scores_df["accuracy"].std(),
            "f1_micro_mean": scores_df["f1_micro"].mean(),
            "f1_micro_std": scores_df["f1_micro"].std(),
            "f1_macro_mean": scores_df["f1_macro"].mean(),
            "f1_macro_std": scores_df["f1_macro"].std(),
        })

    return pd.DataFrame(results)


# Ensure embeddings and labels match
n_nodes = min(len(embeddings), len(communities))
print(f"Embeddings shape: {embeddings.shape}, Labels length: {len(communities)}")
print(f"Using first {n_nodes} nodes")

classification_results = evaluate_node_classification(
    embeddings[:n_nodes], communities[:n_nodes]
)

print("NODE CLASSIFICATION (Community Prediction)")
print("=" * 60)
print(classification_results.to_string(index=False, float_format="%.4f"))

The learning curve, accuracy as a function of label budget, is itself informative. Embeddings that reach high accuracy with only 10% of labels encode richer structure than those that need 70%.

Plot classification learning curve
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Accuracy
ax = axes[0]
ax.errorbar(
    classification_results["train_fraction"],
    classification_results["accuracy_mean"],
    yerr=classification_results["accuracy_std"],
    marker="o", capsize=4, linewidth=2
)
ax.set_xlabel("Training Fraction")
ax.set_ylabel("Accuracy")
ax.set_title("Classification Accuracy vs. Label Budget")
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)

# Macro-F1
ax = axes[1]
ax.errorbar(
    classification_results["train_fraction"],
    classification_results["f1_macro_mean"],
    yerr=classification_results["f1_macro_std"],
    marker="s", capsize=4, linewidth=2, color="tab:orange"
)
ax.set_xlabel("Training Fraction")
ax.set_ylabel("Macro-F1")
ax.set_title("Macro-F1 vs. Label Budget")
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Figure 218.2

218.4 Clustering Alignment

Clustering evaluation asks a complementary question: do natural groups in embedding space match known communities? We measure this with three metrics.

Silhouette Score quantifies how similar objects are to their own cluster relative to other clusters:

\[ s(i) = \frac{b(i) - a(i)}{\max\bigl(a(i),\; b(i)\bigr)} \]

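where \(a(i)\) is the mean distance from point \(i\) to the other points in its own cluster and \(b(i)\) is the mean distance from point \(i\) to the points in the nearest other cluster. Values range from \(-1\) (the point is probably assigned to the wrong cluster) to \(+1\) (dense, well-separated clusters); the overall score is the mean of \(s(i)\) over all points.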

Adjusted Rand Index (ARI) measures agreement between two clusterings, corrected for chance. ARI \(= 1\) indicates perfect agreement; ARI \(\approx 0\) indicates agreement no better than random labeling, and negative values indicate agreement worse than chance.

Normalized Mutual Information (NMI) captures the information-theoretic overlap between predicted and true labels, normalized to \([0, 1]\).
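A tiny worked example makes these definitions concrete. It uses four one-dimensional points chosen for illustration, computes the silhouette of one point by hand, and compares it with scikit-learn's `silhouette_samples`. It also shows that ARI and NMI depend only on the partition, not on the cluster ids.

```python
import numpy as np
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_samples,
)

# Four 1-D points in two obvious clusters
X = np.array([[0.0], [1.0], [10.0], [11.0]])
labels = np.array([0, 0, 1, 1])

# Silhouette of the first point, straight from the formula
a = abs(0.0 - 1.0)                                # mean distance within its own cluster
b = np.mean([abs(0.0 - 10.0), abs(0.0 - 11.0)])   # mean distance to the other cluster
s_manual = (b - a) / max(a, b)
s_sklearn = silhouette_samples(X, labels)[0]
print(f"silhouette: manual={s_manual:.4f}, sklearn={s_sklearn:.4f}")

# ARI and NMI compare partitions, so renaming cluster ids changes nothing
renamed = np.array([1, 1, 0, 0])
ari = adjusted_rand_score(labels, renamed)
nmi = normalized_mutual_info_score(labels, renamed)
print(f"ARI={ari:.4f}, NMI={nmi:.4f}")
```

The label-invariance matters in practice: K-means numbers its clusters arbitrarily, so comparing predicted and true labels with plain accuracy would be meaningless.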

Clustering alignment metrics
def evaluate_clustering(
    embeddings: np.ndarray,
    true_labels: np.ndarray,
    n_clusters_range: List[int] = None,
    n_init: int = 10,
    random_state: int = 42
) -> pd.DataFrame:
    """
    Evaluate embedding clustering quality.

    Runs K-means for each candidate number of clusters and compares
    the resulting assignments against ground-truth labels.

    Parameters
    ----------
    embeddings : np.ndarray
        Embedding matrix of shape (n_nodes, d).
    true_labels : np.ndarray
        Ground-truth cluster/community labels.
    n_clusters_range : list of int, optional
        Numbers of clusters to try. Defaults to a window around the true
        count, from max(2, n_true - 2) through n_true + 3.
    n_init : int
        Number of K-means initializations.
    random_state : int
        Random seed.

    Returns
    -------
    pd.DataFrame
        Clustering quality metrics for each k.
    """
    n_true = len(np.unique(true_labels))

    if n_clusters_range is None:
        n_clusters_range = list(range(
            max(2, n_true - 2), n_true + 4
        ))

    results = []
    for k in n_clusters_range:
        kmeans = KMeans(
            n_clusters=k, n_init=n_init, random_state=random_state
        )
        pred_labels = kmeans.fit_predict(embeddings)

        results.append({
            "n_clusters": k,
            "silhouette": silhouette_score(embeddings, pred_labels),
            "ari": adjusted_rand_score(true_labels, pred_labels),
            "nmi": normalized_mutual_info_score(
                true_labels, pred_labels
            ),
            "inertia": kmeans.inertia_,
        })

    return pd.DataFrame(results)

# Align embeddings and labels
n_nodes = min(len(embeddings), len(communities))
emb_aligned = embeddings[:n_nodes]
com_aligned = communities[:n_nodes]

n_true_communities = len(np.unique(com_aligned))
cluster_range = list(range(
    max(2, n_true_communities - 2),
    n_true_communities + 4
))

cluster_results = evaluate_clustering(
    emb_aligned, com_aligned, n_clusters_range=cluster_range
)

print("CLUSTERING ALIGNMENT")
print("=" * 60)
print(cluster_results.to_string(index=False, float_format="%.4f"))
Plot clustering metrics across k
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

metrics = [
    ("silhouette", "Silhouette Score"),
    ("ari", "Adjusted Rand Index"),
    ("nmi", "Normalized Mutual Information"),
]

for ax, (col, title) in zip(axes, metrics):
    ax.plot(
        cluster_results["n_clusters"],
        cluster_results[col],
        marker="o", linewidth=2
    )
    ax.axvline(
        n_true_communities, color="red", linestyle="--",
        alpha=0.7, label=f"True k = {n_true_communities}"
    )
    ax.set_xlabel("Number of Clusters (k)")
    ax.set_ylabel(title)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Figure 218.3
t-SNE visualization of clusters vs. ground truth
from sklearn.manifold import TSNE

# Use aligned data
n_nodes = min(len(embeddings), len(communities))
emb_aligned = embeddings[:n_nodes]
com_aligned = communities[:n_nodes]

# Find optimal k by ARI
best_k = int(cluster_results.loc[cluster_results["ari"].idxmax(), "n_clusters"])
kmeans_best = KMeans(n_clusters=best_k, n_init=10, random_state=42)
pred_best = kmeans_best.fit_predict(emb_aligned)

# t-SNE projection
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
emb_2d = tsne.fit_transform(emb_aligned)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

scatter_kw = dict(s=20, alpha=0.7, edgecolors="none")

ax = axes[0]
ax.scatter(emb_2d[:, 0], emb_2d[:, 1], c=com_aligned,
           cmap="tab10", **scatter_kw)
ax.set_title("True Communities")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")

ax = axes[1]
ax.scatter(emb_2d[:, 0], emb_2d[:, 1], c=pred_best,
           cmap="tab10", **scatter_kw)
ax.set_title(f"K-Means (k = {best_k})")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")

plt.tight_layout()
plt.show()
Figure 218.4

219 Training Diagnostics

Monitoring the training process reveals problems before they manifest in poor evaluation metrics. Early detection of training pathologies enables timely intervention and prevents wasted computational resources.

219.1 Loss Curves and Convergence

The loss curve provides the primary signal for diagnosing training health (Bottou, Curtis, and Nocedal 2018). Systematic analysis of loss trajectories can identify common failure modes including divergence, oscillation, and premature convergence (Smith 2017).

219.1.1 Theoretical Foundation

Let \(\mathcal{L}(\theta_t)\) denote the loss at training step \(t\) with parameters \(\theta_t\), and let \(\mathcal{L}^*\) denote the optimal loss. Under standard assumptions on the loss landscape and learning rate schedule, we expect:

\[ \mathcal{L}(\theta_t) - \mathcal{L}^* = \mathcal{O}(1/t) \]

for convex objectives (for example, gradient descent on a smooth convex loss, or SGD in expectation with diminishing step sizes on a strongly convex loss), or exponential decay \(\mathcal{L}(\theta_t) - \mathcal{L}^* \propto e^{-\lambda t}\) in favorable non-convex settings (Bottou, Curtis, and Nocedal 2018). Deviation from these patterns signals training pathologies.
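To check which regime a run resembles, one option is to fit both decay models to the excess loss and compare goodness of fit. The sketch below assumes the optimal loss \(\mathcal{L}^*\) is known (here 0) and uses synthetic curves. In practice \(\mathcal{L}^*\) must be estimated, which makes the fit indicative only. `fit_decay_models` is a hypothetical helper.

```python
import numpy as np
from scipy import stats


def fit_decay_models(losses, loss_star=0.0):
    """Fit power-law and exponential decay to the excess loss (hypothetical helper)."""
    excess = np.asarray(losses, dtype=float) - loss_star
    t = np.arange(1, len(excess) + 1)
    log_excess = np.log(excess)        # requires excess > 0 everywhere

    power = stats.linregress(np.log(t), log_excess)   # log L = c - p * log t
    expo = stats.linregress(t, log_excess)            # log L = c - lambda * t
    return {
        "power_exponent": -power.slope,
        "power_r2": power.rvalue ** 2,
        "exp_rate": -expo.slope,
        "exp_r2": expo.rvalue ** 2,
    }


t = np.arange(1, 201)
sublinear = fit_decay_models(1.0 / t)                # O(1/t) curve
linear_rate = fit_decay_models(np.exp(-0.05 * t))    # exponential curve

print(sublinear)
print(linear_rate)
```

A power-law exponent near 1 with the better fit suggests the \(\mathcal{O}(1/t)\) regime, while a better exponential fit suggests the faster regime. Noisy per-epoch losses should be smoothed before fitting.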

Comprehensive training diagnostics with statistical tests
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from typing import Dict, List, Tuple
import warnings

class TrainingMonitor:
    """
    Monitor training health for embedding models.
    
    
    Parameters
    ----------
    window_size : int
        Number of recent epochs to analyze for trend detection
    smoothing_alpha : float
        Exponential smoothing parameter (0 < alpha <= 1)
    
    """
    
    def __init__(self, window_size: int = 50, smoothing_alpha: float = 0.1):
        self.history = {
            'loss': [], 
            'grad_norm': [], 
            'epoch': [],
            'learning_rate': [],
            'batch_loss_variance': []
        }
        self.window_size = window_size
        self.alpha = smoothing_alpha
        self.smoothed_loss = []
    
    def log(self, epoch: int, loss: float, grad_norm: float = None,
            learning_rate: float = None, batch_variance: float = None):
        """Log training metrics for an epoch."""
        self.history['epoch'].append(epoch)
        self.history['loss'].append(loss)
        self.history['grad_norm'].append(grad_norm)
        self.history['learning_rate'].append(learning_rate)
        self.history['batch_loss_variance'].append(batch_variance)
        
        # Exponential smoothing
        if len(self.smoothed_loss) == 0:
            self.smoothed_loss.append(loss)
        else:
            smoothed = self.alpha * loss + (1 - self.alpha) * self.smoothed_loss[-1]
            self.smoothed_loss.append(smoothed)
    
    def diagnose_loss_curve(self) -> Dict:
        """
        Comprehensive loss curve diagnostics.
        
        Returns
        -------
        Dict with diagnostic results including:
        - divergence: Boolean indicating training divergence
        - plateau: Boolean and statistical test results
        - oscillation: Boolean and frequency analysis
        - convergence_rate: Estimated convergence coefficient
        - recommendations: List of actionable suggestions
        """
        losses = np.array(self.history['loss'])
        
        if len(losses) < 10:
            return {'status': 'insufficient_data',
                    'message': 'Need at least 10 epochs for diagnosis'}
        
        diagnosis = {
            'n_epochs': len(losses),
            'final_loss': losses[-1],
            'min_loss': np.min(losses),
            'recommendations': []
        }
        
        # 1. Check for divergence (NaN/Inf or monotonic increase)
        if np.any(np.isnan(losses)) or np.any(np.isinf(losses)):
            diagnosis['divergence'] = True
            diagnosis['status'] = 'diverged'  # every return path exposes a status
            diagnosis['severity'] = 'critical'
            diagnosis['message'] = "Training diverged (NaN/Inf detected)"
            diagnosis['recommendations'].extend([
                "Reduce learning rate by 10x",
                "Check for gradient clipping",
                "Verify data normalization",
                "Inspect batch statistics"
            ])
            return diagnosis
        
        # Check monotonic increase in recent history
        recent_losses = losses[-20:] if len(losses) >= 20 else losses
        if len(recent_losses) > 5:
            trend, _, _, p_value, _ = stats.linregress(
                range(len(recent_losses)), recent_losses
            )
            if trend > 0 and p_value < 0.05:
                diagnosis['divergence'] = True
                diagnosis['severity'] = 'high'
                diagnosis['trend_coefficient'] = trend
                diagnosis['trend_p_value'] = p_value
                diagnosis['message'] = "Loss increasing (possible divergence)"
                diagnosis['recommendations'].append(
                    "Reduce learning rate immediately"
                )
        # 2. Test for plateau using statistical methods
        plateau_result = self._test_plateau(losses)
        # Merge recommendations
        if 'recommendations' in plateau_result:
            diagnosis['recommendations'].extend(plateau_result.pop('recommendations'))
        diagnosis.update(plateau_result)
        
        # 3. Test for oscillation
        oscillation_result = self._test_oscillation(losses)
        # Merge recommendations
        if 'recommendations' in oscillation_result:
            diagnosis['recommendations'].extend(oscillation_result.pop('recommendations'))
        diagnosis.update(oscillation_result)
        
        # 4. Estimate convergence rate
        convergence_result = self._estimate_convergence_rate(losses)
        diagnosis.update(convergence_result)
        
        # 5. Gradient norm analysis
        if any(g is not None for g in self.history['grad_norm']):
            gradient_result = self._analyze_gradients()
            # Merge recommendations
            if 'recommendations' in gradient_result:
                diagnosis['recommendations'].extend(gradient_result.pop('recommendations'))
            diagnosis.update(gradient_result)
        
        # 6. Overall status
        if diagnosis.get('divergence', False):
            # Keep the divergence message set above, but always expose a status
            diagnosis['status'] = 'diverging'
        elif diagnosis.get('is_plateau', False):
            diagnosis['status'] = 'plateau'
            diagnosis['message'] = diagnosis.get('plateau_message', 'Training plateaued')
        elif diagnosis.get('is_oscillating', False):
            diagnosis['status'] = 'oscillating'
            diagnosis['message'] = diagnosis.get('oscillation_message', 'Training oscillating')
        else:
            diagnosis['status'] = 'healthy'
            diagnosis['message'] = "Training progressing normally"
        
        return diagnosis
        
        
   
    def _test_plateau(self, losses: np.ndarray) -> Dict:
        """
        Statistical test for plateau detection.
        
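        Uses two criteria on the most recent window:
        1. Range test: relative range, (max - min) / mean
        2. Trend test: linear regression slope and its p-value

        A KPSS stationarity test is a natural third criterion but is not
        implemented in this method.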
        
        References
        ----------
        Kwiatkowski et al. (1992). Testing the null hypothesis of stationarity.
        """
        result = {'recommendations': []}  # Initialize recommendations list
        window = min(self.window_size, len(losses))
        recent = losses[-window:]
        
        # Range-based test
        loss_range = np.max(recent) - np.min(recent)
        mean_loss = np.mean(recent)
        relative_range = loss_range / (mean_loss + 1e-10)
        
        # Trend test
        trend, intercept, _, p_value, std_err = stats.linregress(
            range(len(recent)), recent
        )
        
        # Determine plateau
        is_plateau = (
            relative_range < 0.01 and  # Less than 1% variation
            abs(trend) < 0.0001 and    # Minimal trend
            p_value > 0.05              # No significant trend
        )
        
        result['is_plateau'] = is_plateau
        result['plateau_relative_range'] = relative_range
        result['plateau_trend'] = trend
        result['plateau_p_value'] = p_value
        
        if is_plateau:
            result['plateau_message'] = (
                f"Plateau detected: {relative_range:.1%} variation "
                f"over {window} epochs"
            )
            result['recommendations'].extend([
                "Consider increasing model capacity",
                "Try a learning rate warmup or restart",
                "Check optimization hyperparameters",
                "Verify sufficient training data diversity"
            ])
        
        return result

    
    
    def _test_oscillation(self, losses: np.ndarray) -> Dict:
        """
        Detect oscillatory behavior in training.
        
        Methods:
        1. Sign-change rate of the first difference
        2. Lag-1 autocorrelation of the first difference

        Spectral analysis for dominant frequencies is not implemented here.
        """
        result = {'recommendations': []}  # Initialize recommendations list
        window = min(self.window_size, len(losses))
        recent = losses[-window:]
        
        if len(recent) < 10:
            result['is_oscillating'] = False
            return result
        
        # First derivative and zero crossings
        diffs = np.diff(recent)
        sign_changes = np.sum(np.diff(np.sign(diffs)) != 0)
        oscillation_rate = sign_changes / len(diffs)
        
        # Autocorrelation at lag 1
        if len(diffs) > 1:
            acf_1 = np.corrcoef(diffs[:-1], diffs[1:])[0, 1]
        else:
            acf_1 = 0
        
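        # Note: for independent noise around a flat loss, the sign-change rate
        # is about 2/3 and the lag-1 autocorrelation of the differences is
        # about -0.5, so these thresholds can also fire on noisy but
        # non-oscillating curves.
        is_oscillating = oscillation_rate > 0.6 or acf_1 < -0.3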
        
        result['is_oscillating'] = is_oscillating
        result['oscillation_rate'] = oscillation_rate
        result['loss_autocorr'] = acf_1
        
        if is_oscillating:
            result['oscillation_message'] = (
                f"Oscillation detected: {oscillation_rate:.1%} sign changes"
            )
            result['recommendations'].extend([
                "Reduce learning rate by 2-5x",
                "Consider adaptive learning rate (Adam, RMSprop)",
                "Increase batch size to reduce noise",
                "Add gradient clipping if not present"
            ])
        
        return result
    
    def _estimate_convergence_rate(self, losses: np.ndarray) -> Dict:
        """
        Estimate convergence rate assuming exponential decay.
        
        Fits: L(t) = L_∞ + (L_0 - L_∞) * exp(-λt)
        """
        result = {}
        
        if len(losses) < 20:
            return result
        
        # Use log transform for linear fit
        # log(L(t) - L_min) ≈ log(L_0 - L_min) - λt
        L_min = np.min(losses)
        normalized = losses - L_min + 1e-6  # Avoid log(0)
        
        try:
            # Only use middle portion to avoid initialization and plateau effects
            start_idx = len(losses) // 4
            end_idx = 3 * len(losses) // 4
            
            if end_idx - start_idx > 10:
                x = np.arange(start_idx, end_idx)
                y = np.log(normalized[start_idx:end_idx])
                
                slope, _, _, p_value, _ = stats.linregress(x, y)
                
                result['convergence_rate'] = -slope
                result['convergence_p_value'] = p_value
                result['estimated_asymptote'] = L_min
                
                # Epochs for the excess loss (L - L_min) to shrink to 1% of
                # its current value under the fitted decay: ln(100) / λ
                if slope < -0.001:  # Meaningful convergence
                    epochs_remaining = -np.log(0.01) / (-slope)
                    result['estimated_epochs_to_converge'] = int(epochs_remaining)
        
        except Exception as e:
            warnings.warn(f"Convergence estimation failed: {e}")
        
        return result
    
    def _analyze_gradients(self) -> Dict:
        """Analyze gradient norm trajectory."""
        result = {'recommendations': []}  # Initialize recommendations list
        grad_norms = [g for g in self.history['grad_norm'] if g is not None]
        
        if len(grad_norms) < 10:
            return result
        
        grad_norms = np.array(grad_norms)
        
        result['mean_grad_norm'] = np.mean(grad_norms)
        result['grad_norm_std'] = np.std(grad_norms)
        result['grad_norm_trend'] = np.polyfit(range(len(grad_norms)), grad_norms, 1)[0]
        
        # Check for vanishing gradients
        if result['mean_grad_norm'] < 1e-6:
            result['vanishing_gradients'] = True
            result['recommendations'].append(
                "Vanishing gradients detected - check activation functions"
            )
        
        # Check for exploding gradients
        if result['mean_grad_norm'] > 100 or np.max(grad_norms) > 1000:
            result['exploding_gradients'] = True
            result['recommendations'].append(
                "Exploding gradients - add gradient clipping"
            )
        
        return result
    
    
    
    def plot_diagnostics(self, figsize=(15, 10)):
        """
        Create comprehensive diagnostic plots.
        
        Generates a 2x3 grid showing:
        1. Loss curve with smoothing
        2. Loss derivative (learning speed)
        3. Gradient norms
        4. Learning rate schedule
        5. Loss distribution
        6. Convergence analysis
        """
        fig, axes = plt.subplots(2, 3, figsize=figsize)
        epochs = np.array(self.history['epoch'])
        losses = np.array(self.history['loss'])
        
        # 1. Loss curve
        ax = axes[0, 0]
        ax.plot(epochs, losses, 'b-', alpha=0.3, label='Raw loss')
        if self.smoothed_loss:
            ax.plot(epochs, self.smoothed_loss, 'b-', linewidth=2, 
                   label=f'Smoothed (α={self.alpha})')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss Trajectory')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # 2. Loss derivative
        ax = axes[0, 1]
        if len(losses) > 1:
            loss_deriv = np.diff(losses)
            ax.plot(epochs[1:], loss_deriv, 'g-', alpha=0.6)
            ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('ΔLoss')
            ax.set_title('Loss Derivative (Learning Speed)')
            ax.grid(True, alpha=0.3)
        
        # 3. Gradient norms
        ax = axes[0, 2]
        grad_norms = [g for g in self.history['grad_norm'] if g is not None]
        if grad_norms:
            ax.plot(epochs[:len(grad_norms)], grad_norms, 'r-', alpha=0.6)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Gradient Norm')
            ax.set_title('Gradient Magnitude')
            ax.set_yscale('log')
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'No gradient data', 
                   ha='center', va='center', transform=ax.transAxes)
        
        # 4. Learning rate schedule
        ax = axes[1, 0]
        lrs = [lr for lr in self.history['learning_rate'] if lr is not None]
        if lrs:
            ax.plot(epochs[:len(lrs)], lrs, 'm-', linewidth=2)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Learning Rate')
            ax.set_title('Learning Rate Schedule')
            ax.set_yscale('log')
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'No LR data', 
                   ha='center', va='center', transform=ax.transAxes)
        
        # 5. Loss distribution
        ax = axes[1, 1]
        ax.hist(losses, bins=30, alpha=0.7, edgecolor='black')
        ax.axvline(np.mean(losses), color='r', linestyle='--', 
                  label=f'Mean: {np.mean(losses):.3f}')
        ax.axvline(np.median(losses), color='g', linestyle='--',
                  label=f'Median: {np.median(losses):.3f}')
        ax.set_xlabel('Loss')
        ax.set_ylabel('Frequency')
        ax.set_title('Loss Distribution')
        ax.legend()
        
        # 6. Convergence analysis
        ax = axes[1, 2]
        if len(losses) > 20:
            # Plot normalized loss
            normalized = (losses - np.min(losses)) / (np.max(losses) - np.min(losses) + 1e-10)
            ax.semilogy(epochs, normalized + 1e-6, 'b-', alpha=0.6, label='Normalized loss')
            
            # Add exponential fit
            try:
                mid_start = len(losses) // 4
                mid_end = 3 * len(losses) // 4
                x_fit = epochs[mid_start:mid_end]
                y_fit = normalized[mid_start:mid_end] + 1e-6
                
                coeffs = np.polyfit(x_fit, np.log(y_fit), 1)
                y_pred = np.exp(coeffs[1] + coeffs[0] * epochs)
                ax.plot(epochs, y_pred, 'r--', label=f'Exponential fit (λ={-coeffs[0]:.4f})')
            except Exception:  # Skip the fit overlay if it cannot be computed
                pass
            
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Normalized Loss (log scale)')
            ax.set_title('Convergence Analysis')
            ax.legend()
            ax.grid(True, alpha=0.3)
        
        plt.tight_layout()
        return fig


def simulate_training(scenario: str, n_epochs: int = 100, 
                    seed: int = 42) -> TrainingMonitor:
    """
    Simulate training scenarios for diagnostic demonstration.
    
    Parameters
    ----------
    scenario : {'healthy', 'oscillating', 'plateau', 'diverging', 
                'vanishing_gradients', 'slow_convergence'}
    n_epochs : int
    seed : int
    
    Returns
    -------
    TrainingMonitor with simulated training history
    """
    monitor = TrainingMonitor()
    np.random.seed(seed)
    
    if scenario == 'healthy':
        # Exponential decay with noise
        lr_schedule = np.logspace(-3, -4, n_epochs)
        for e in range(n_epochs):
            loss = 2.0 * np.exp(-0.03 * e) + 0.1 + np.random.normal(0, 0.02)
            grad_norm = 1.0 * np.exp(-0.02 * e) + np.random.normal(0, 0.05)
            monitor.log(e, loss, max(grad_norm, 0), lr_schedule[e])
    
    elif scenario == 'oscillating':
        # High learning rate causing oscillation
        lr = 0.1  # Too high
        for e in range(n_epochs):
            loss = 1.0 + 0.5 * np.sin(0.5 * e) * np.exp(-0.01 * e) + np.random.normal(0, 0.1)
            grad_norm = 2.0 + np.random.normal(0, 0.5)
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    elif scenario == 'plateau':
        # Early plateau due to insufficient capacity
        for e in range(n_epochs):
            if e < 20:
                loss = 2.0 - 0.08 * e + np.random.normal(0, 0.02)
                lr = 0.001
            else:
                loss = 0.4 + np.random.normal(0, 0.01)
                lr = 0.001 * np.exp(-0.05 * (e - 20))
            grad_norm = 0.5 * np.exp(-0.02 * e) + np.random.normal(0, 0.02)
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    elif scenario == 'diverging':
        # Divergence after initial progress
        for e in range(n_epochs):
            if e < 30:
                loss = 1.0 + np.random.normal(0, 0.1)
                grad_norm = 0.5 + np.random.normal(0, 0.1)
                lr = 0.01
            else:
                loss = 1.0 + 0.1 * (e - 30) ** 1.5
                grad_norm = 0.1 * np.exp(0.1 * (e - 30))
                lr = 0.01
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    elif scenario == 'vanishing_gradients':
        # Gradients vanish due to poor initialization
        for e in range(n_epochs):
            loss = 2.0 - 0.005 * e + np.random.normal(0, 0.01)
            grad_norm = max(1e-4 * np.exp(-0.05 * e), 1e-8)
            lr = 0.001
            monitor.log(e, loss, grad_norm, lr)
    
    elif scenario == 'slow_convergence':
        # Very slow convergence (learning rate too small)
        for e in range(n_epochs):
            loss = 2.0 * np.exp(-0.005 * e) + 0.5 + np.random.normal(0, 0.01)
            grad_norm = 0.3 + np.random.normal(0, 0.05)
            lr = 1e-5
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
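    else:
        # Fail loudly instead of returning an empty monitor
        raise ValueError(f"Unknown scenario: {scenario!r}")

    return monitor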


# Demonstrate all scenarios
scenarios = ['healthy', 'oscillating', 'plateau', 'diverging', 
             'vanishing_gradients', 'slow_convergence']

print("TRAINING DIAGNOSTIC SCENARIOS")
print("=" * 80)

# Create summary table
results_table = []

for scenario in scenarios:
    monitor = simulate_training(scenario, n_epochs=100)
    diagnosis = monitor.diagnose_loss_curve()
    
    results_table.append({
        'Scenario': scenario.replace('_', ' ').title(),
        'Status': diagnosis.get('status', 'N/A'),
        'Final Loss': f"{diagnosis.get('final_loss', 0):.3f}",
        'Converged': 'Yes' if diagnosis.get('convergence_rate', 0) > 0.01 else 'No',
        'Issue Detected': diagnosis.get('message', 'None'),
        'Primary Recommendation': (diagnosis.get('recommendations') or ['None'])[0]
    })

import pandas as pd
results_df = pd.DataFrame(results_table)
print("\nTable 1: Training Diagnostic Summary")
print(results_df.to_string(index=False))
print()

# Visualize scenarios
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for idx, scenario in enumerate(scenarios):
    monitor = simulate_training(scenario)
    diagnosis = monitor.diagnose_loss_curve()
    
    ax = axes[idx // 3, idx % 3]
    
    # Plot loss and smoothed version
    epochs = monitor.history['epoch']
    losses = monitor.history['loss']
    ax.plot(epochs, losses, 'b-', alpha=0.3, linewidth=1)
    ax.plot(epochs, monitor.smoothed_loss, 'b-', linewidth=2)
    
    # Add annotations
    status = diagnosis.get('status', 'unknown')
    color_map = {
        'healthy': 'green',
        'oscillating': 'orange',
        'plateau': 'red',
        'diverging': 'darkred',
        'insufficient_data': 'gray'
    }
    color = color_map.get(status, 'black')
    
    ax.set_xlabel('Epoch', fontsize=10)
    ax.set_ylabel('Loss', fontsize=10)
    title = f"{scenario.replace('_', ' ').title()}\n{diagnosis.get('message', '')}"
    ax.set_title(title, fontsize=10, color=color, fontweight='bold')
    ax.grid(True, alpha=0.3)
    
    # Add status indicator
    ax.text(0.95, 0.95, f"Status: {status}", 
           transform=ax.transAxes, ha='right', va='top',
           bbox=dict(boxstyle='round', facecolor=color, alpha=0.3),
           fontsize=8)

plt.suptitle('Figure 1: Training Diagnostic Scenarios', 
             fontsize=14, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

# Detailed diagnostic plot for one scenario
print("\nDetailed Diagnostics for 'Oscillating' Scenario:")
print("-" * 80)
monitor = simulate_training('oscillating')
diagnosis = monitor.diagnose_loss_curve()

for key, value in diagnosis.items():
    if key != 'recommendations':
        print(f"{key}: {value}")

if diagnosis.get('recommendations'):
    print("\nRecommendations:")
    for i, rec in enumerate(diagnosis['recommendations'], 1):
        print(f"  {i}. {rec}")

monitor.plot_diagnostics()
plt.suptitle('Figure 2: Comprehensive Diagnostics for Oscillating Training',
             fontsize=14, fontweight='bold', y=1.00)
plt.show()
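
Thresholds like those in `_test_oscillation` can be sanity-checked by computing the same two statistics on data with no oscillation at all. The sketch below is an illustration under stated assumptions, not part of the monitor: the helper name `oscillation_stats`, the noise level, and the sample sizes are all hypothetical choices made for this example.

```python
import numpy as np


def oscillation_stats(losses):
    """Return (sign-change rate, lag-1 autocorrelation) of the loss differences.

    Mirrors the two statistics computed in `_test_oscillation`. The function
    name is hypothetical and exists only for this illustration.
    """
    diffs = np.diff(np.asarray(losses, dtype=float))
    sign_changes = np.sum(np.diff(np.sign(diffs)) != 0)
    rate = sign_changes / len(diffs)
    acf_1 = np.corrcoef(diffs[:-1], diffs[1:])[0, 1]
    return float(rate), float(acf_1)


rng = np.random.default_rng(0)

# Baseline: a flat loss with independent Gaussian noise (no oscillation).
noise_only = 1.0 + rng.normal(0.0, 0.05, size=20_000)
noise_rate, noise_acf = oscillation_stats(noise_only)

# Comparison: a smooth, noise-free exponential decay.
t = np.arange(200)
smooth_decay = 2.0 * np.exp(-0.03 * t) + 0.1
decay_rate, decay_acf = oscillation_stats(smooth_decay)

print(f"noise only  : rate={noise_rate:.3f}, acf_1={noise_acf:.3f}")
print(f"smooth decay: rate={decay_rate:.3f}, acf_1={decay_acf:.3f}")
```

For independent noise the sign-change rate sits near 2/3 and the lag-1 autocorrelation of the differences near -0.5, so cutoffs of 0.6 and -0.3 will tend to flag a flat, noisy loss as oscillating. Calibrating the cutoffs against a noise baseline like this one, or applying the test to the smoothed loss, may reduce false alarms.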
# Training Diagnostics {#sec-training}

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats, signal
from typing import Dict, List, Tuple, Optional
import warnings
import pandas as pd

# Statistical testing packages
from statsmodels.tsa.stattools import adfuller, kpss  # Stationarity tests
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf  # Time series diagnostics
from statsmodels.tsa.seasonal import seasonal_decompose  # Decompose trends

# Anomaly detection
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler

# Change point detection
try:
    import ruptures as rpt  # Efficient change point detection
    HAS_RUPTURES = True
except ImportError:
    HAS_RUPTURES = False
    warnings.warn("Install ruptures for change point detection: pip install ruptures")

# Bayesian optimization diagnostics
try:
    from bayes_opt import BayesianOptimization
    HAS_BAYESOPT = True
except ImportError:
    HAS_BAYESOPT = False

# TensorBoard-style monitoring
try:
    from torch.utils.tensorboard import SummaryWriter
    HAS_TENSORBOARD = True
except ImportError:
    HAS_TENSORBOARD = False

# Weights & Biases alternative
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False


class TrainingMonitor:
    """
    Advanced training monitor leveraging statistical packages.
    
    Uses:
    - statsmodels: Time series analysis, stationarity tests
    - ruptures: Change point detection
    - scikit-learn: Anomaly detection
    - scipy.signal: Spectral analysis
    
    References
    ----------
    Killick, R., Fearnhead, P., & Eckley, I. A. (2012). Optimal detection 
        of changepoints with a linear computational cost. JASA.
    Kwiatkowski, D., et al. (1992). Testing the null hypothesis of 
        stationarity against the alternative of a unit root. Journal of 
        Econometrics.
    """
    
    def __init__(self, window_size: int = 50, smoothing_alpha: float = 0.1,
                 use_tensorboard: bool = False, log_dir: str = './logs'):
        self.history = {
            'loss': [], 
            'grad_norm': [], 
            'epoch': [],
            'learning_rate': [],
            'batch_loss_variance': []
        }
        self.window_size = window_size
        self.alpha = smoothing_alpha
        self.smoothed_loss = []
        
        # TensorBoard integration
        self.use_tensorboard = use_tensorboard and HAS_TENSORBOARD
        if self.use_tensorboard:
            self.writer = SummaryWriter(log_dir)
    
    def log(self, epoch: int, loss: float, grad_norm: Optional[float] = None,
            learning_rate: Optional[float] = None,
            batch_variance: Optional[float] = None):
        """Log training metrics for an epoch."""
        self.history['epoch'].append(epoch)
        self.history['loss'].append(loss)
        self.history['grad_norm'].append(grad_norm)
        self.history['learning_rate'].append(learning_rate)
        self.history['batch_loss_variance'].append(batch_variance)
        
        # Exponential smoothing (manual update; alpha weights the newest loss)
        if len(self.smoothed_loss) == 0:
            self.smoothed_loss.append(loss)
        else:
            smoothed = self.alpha * loss + (1 - self.alpha) * self.smoothed_loss[-1]
            self.smoothed_loss.append(smoothed)
        
        # TensorBoard logging
        if self.use_tensorboard:
            self.writer.add_scalar('Loss/train', loss, epoch)
            if grad_norm is not None:
                self.writer.add_scalar('Gradients/norm', grad_norm, epoch)
            if learning_rate is not None:
                self.writer.add_scalar('Learning_Rate', learning_rate, epoch)
    
    def diagnose_loss_curve(self) -> Dict:
        """Comprehensive diagnostics using statistical packages."""
        losses = np.array(self.history['loss'])
        
        if len(losses) < 10:
            return {
                'status': 'insufficient_data',
                'message': 'Need at least 10 epochs for diagnosis',
                'recommendations': []
            }
        
        diagnosis = {
            'n_epochs': len(losses),
            'final_loss': losses[-1],
            'min_loss': np.min(losses),
            'recommendations': []
        }
        
        # 1. Divergence checks
        if np.any(np.isnan(losses)) or np.any(np.isinf(losses)):
            diagnosis['divergence'] = True
            diagnosis['severity'] = 'critical'
            diagnosis['message'] = "Training diverged (NaN/Inf detected)"
            diagnosis['recommendations'].extend([
                "Reduce learning rate by 10x",
                "Check for gradient clipping",
                "Verify data normalization",
                "Inspect batch statistics"
            ])
            return diagnosis
        
        # Each helper returns its own 'recommendations' list. Merge it into
        # the running list before update(); otherwise update() would replace
        # the list and discard recommendations from earlier checks.
        def merge(result: Dict) -> None:
            diagnosis['recommendations'].extend(result.pop('recommendations', []))
            diagnosis.update(result)

        # 2. Stationarity test using statsmodels
        merge(self._test_stationarity(losses))

        # 3. Change point detection
        if HAS_RUPTURES:
            merge(self._detect_changepoints(losses))

        # 4. Spectral analysis for oscillations
        merge(self._spectral_analysis(losses))

        # 5. Anomaly detection
        merge(self._detect_anomalies(losses))

        # 6. Trend decomposition
        merge(self._decompose_trend(losses))

        # 7. Convergence analysis
        merge(self._estimate_convergence_rate(losses))

        # 8. Gradient analysis
        if any(g is not None for g in self.history['grad_norm']):
            merge(self._analyze_gradients())
        
        # Overall status
        if not diagnosis.get('divergence', False):
            if diagnosis.get('is_nonstationary', False):
                diagnosis['status'] = 'nonstationary'
                diagnosis['message'] = 'Loss not converging (non-stationary)'
            elif diagnosis.get('has_changepoints', False):
                diagnosis['status'] = 'unstable'
                diagnosis['message'] = f"Detected {diagnosis.get('n_changepoints', 0)} regime changes"
            elif diagnosis.get('is_oscillating', False):
                diagnosis['status'] = 'oscillating'
                diagnosis['message'] = diagnosis.get('oscillation_message', 'Training oscillating')
            else:
                diagnosis['status'] = 'healthy'
                diagnosis['message'] = "Training progressing normally"
        
        return diagnosis
    
    def _test_stationarity(self, losses: np.ndarray) -> Dict:
        """
        Test for stationarity using ADF and KPSS tests.
        
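        Notes
        -----
        Augmented Dickey-Fuller (ADF): the null hypothesis is a unit root
            (non-stationary), so a small p-value suggests stationarity.
        KPSS: the null hypothesis is stationarity, so a small p-value
            suggests non-stationarity.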
        """
        result = {'recommendations': []}
        
        if len(losses) < 20:
            return result
        
        # Augmented Dickey-Fuller test
        # H0: Unit root exists (non-stationary)
        try:
            adf_result = adfuller(losses, autolag='AIC')
            adf_statistic, adf_pvalue = adf_result[0], adf_result[1]
            
            result['adf_statistic'] = adf_statistic
            result['adf_pvalue'] = adf_pvalue
            result['is_stationary_adf'] = adf_pvalue < 0.05
        except Exception as e:
            warnings.warn(f"ADF test failed: {e}")
        
        # KPSS test
        # H0: Series is stationary
        try:
            kpss_result = kpss(losses, regression='c', nlags='auto')
            kpss_statistic, kpss_pvalue = kpss_result[0], kpss_result[1]
            
            result['kpss_statistic'] = kpss_statistic
            result['kpss_pvalue'] = kpss_pvalue
            result['is_stationary_kpss'] = kpss_pvalue > 0.05
        except Exception as e:
            warnings.warn(f"KPSS test failed: {e}")
        
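        # Combined interpretation.
        # Caveat: a loss that is still falling steadily is also non-stationary,
        # so 'is_nonstationary' means the loss level is still changing. It does
        # not by itself show that training is failing.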
        if result.get('is_stationary_adf') and result.get('is_stationary_kpss'):
            result['is_nonstationary'] = False
            result['stationarity_message'] = "Loss is stationary (converging)"
        elif not result.get('is_stationary_adf', True) and not result.get('is_stationary_kpss', True):
            result['is_nonstationary'] = True
            result['stationarity_message'] = "Loss is non-stationary (not converging)"
            result['recommendations'].extend([
                "Training not converging - check learning rate",
                "Consider learning rate schedule (decay)",
                "Verify model capacity is sufficient"
            ])
        else:
            result['is_nonstationary'] = None
            result['stationarity_message'] = "Stationarity tests inconclusive"
        
        return result
    
    def _detect_changepoints(self, losses: np.ndarray) -> Dict:
        """
        Detect regime changes using ruptures library.
        
        References
        ----------
        Killick et al. (2012). Optimal detection of changepoints.
        """
        result = {'recommendations': []}
        
        if not HAS_RUPTURES or len(losses) < 30:
            return result
        
        try:
            # Use Pelt algorithm for efficient change point detection
            algo = rpt.Pelt(model="rbf", min_size=10, jump=1).fit(losses)
            changepoints = algo.predict(pen=3)  # Penalty parameter
            
            # Remove the final endpoint
            changepoints = [cp for cp in changepoints if cp < len(losses)]
            
            result['changepoints'] = changepoints
            result['n_changepoints'] = len(changepoints)
            result['has_changepoints'] = len(changepoints) > 0
            
            if len(changepoints) > 0:
                result['changepoint_message'] = (
                    f"Detected {len(changepoints)} regime changes at epochs: "
                    f"{changepoints[:3]}{'...' if len(changepoints) > 3 else ''}"
                )
                result['recommendations'].extend([
                    "Multiple training regimes detected",
                    "Consider learning rate warmup/restart at regime changes",
                    "Check for data distribution shifts"
                ])
            
            # Analyze every segment, including the one after the last changepoint
            if len(changepoints) > 0:
                segments = [0] + changepoints + [len(losses)]
                segment_trends = []
                
                for i in range(len(segments) - 1):
                    start, end = segments[i], segments[i + 1]
                    segment = losses[start:end]
                    if len(segment) > 2:
                        trend = np.polyfit(range(len(segment)), segment, 1)[0]
                        segment_trends.append(trend)
                
                result['segment_trends'] = segment_trends
                
        except Exception as e:
            warnings.warn(f"Change point detection failed: {e}")
        
        return result
    
    def _spectral_analysis(self, losses: np.ndarray) -> Dict:
        """
        Detect oscillations using FFT and periodogram.
        
        Uses scipy.signal for spectral analysis.
        """
        result = {'recommendations': []}
        
        if len(losses) < 20:
            return result
        
        try:
            # Detrend first
            detrended = signal.detrend(losses)
            
            # Compute periodogram
            freqs, power = signal.periodogram(detrended, scaling='spectrum')
            
            # Find dominant frequencies (excluding DC component)
            if len(freqs) > 1:
                dominant_idx = np.argmax(power[1:]) + 1
                dominant_freq = freqs[dominant_idx]
                dominant_power = power[dominant_idx]
                
                # Period in epochs
                if dominant_freq > 0:
                    period = 1.0 / dominant_freq
                else:
                    period = np.inf
                
                result['dominant_frequency'] = dominant_freq
                result['dominant_period'] = period
                result['spectral_power'] = dominant_power
                
                # Check if oscillation is significant
                mean_power = np.mean(power[1:])
                peak_ratio = dominant_power / (mean_power + 1e-10)
                
                result['peak_ratio'] = peak_ratio
                
                if peak_ratio > 3 and period < 50:
                    result['is_oscillating'] = True
                    result['oscillation_message'] = (
                        f"Oscillation detected with period ~{period:.1f} epochs"
                    )
                    result['recommendations'].extend([
                        f"Reduce learning rate (oscillating every {period:.1f} epochs)",
                        "Consider adaptive optimizers (Adam, RMSprop)",
                        "Increase batch size to reduce noise"
                    ])
                else:
                    result['is_oscillating'] = False
            
        except Exception as e:
            warnings.warn(f"Spectral analysis failed: {e}")
        
        return result
    
    def _detect_anomalies(self, losses: np.ndarray) -> Dict:
        """Detect anomalous epochs using Isolation Forest."""
        result = {'recommendations': []}
        
        if len(losses) < 30:
            return result
        
        try:
            # Prepare features: loss, loss derivative, second derivative
            X = np.column_stack([
                losses,
                np.concatenate([[0], np.diff(losses)]),
                np.concatenate([[0, 0], np.diff(losses, n=2)])
            ])
            
            # Standardize
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X)
            
            # Isolation Forest
            iso_forest = IsolationForest(contamination=0.1, random_state=42)
            anomaly_labels = iso_forest.fit_predict(X_scaled)
            
            anomaly_epochs = np.where(anomaly_labels == -1)[0]
            
            result['anomaly_epochs'] = anomaly_epochs.tolist()
            result['n_anomalies'] = len(anomaly_epochs)
            result['anomaly_rate'] = len(anomaly_epochs) / len(losses)
            
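            # Note: contamination=0.1 makes the forest flag roughly 10% of
            # epochs by construction, so this 15% check will rarely trigger.
            if len(anomaly_epochs) > len(losses) * 0.15: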
                result['recommendations'].append(
                    f"High anomaly rate ({result['anomaly_rate']:.1%}) - check data quality"
                )
            
        except Exception as e:
            warnings.warn(f"Anomaly detection failed: {e}")
        
        return result
    
    def _decompose_trend(self, losses: np.ndarray) -> Dict:
        """
        Decompose loss into trend, seasonal, and residual components.
        
        Uses statsmodels seasonal_decompose.
        """
        result = {'recommendations': []}
        
        if len(losses) < 30:
            return result
        
        try:
            # Create time series
            ts = pd.Series(losses)
            
            # Decompose (need at least 2 periods)
            period = min(10, len(losses) // 3)
            if period >= 2:
                decomposition = seasonal_decompose(
                    ts, 
                    model='additive', 
                    period=period,
                    extrapolate_trend='freq'
                )
                
                result['trend'] = decomposition.trend.values
                result['seasonal'] = decomposition.seasonal.values
                result['residual'] = decomposition.resid.values
                
                # Analyze trend direction
                trend_clean = decomposition.trend.dropna()
                if len(trend_clean) > 5:
                    trend_slope = np.polyfit(
                        range(len(trend_clean)), 
                        trend_clean, 
                        1
                    )[0]
                    
                    result['trend_slope'] = trend_slope
                    
                    # Test for a flat trend first so that tiny positive slopes
                    # are not reported as an increasing loss
                    if abs(trend_slope) < 1e-4:
                        result['trend_direction'] = 'flat'
                        result['recommendations'].append(
                            "Loss plateaued - consider model capacity or learning rate restart"
                        )
                    elif trend_slope > 0:
                        result['trend_direction'] = 'increasing'
                        result['recommendations'].append(
                            "Loss trending upward - reduce learning rate"
                        )
                    else:
                        result['trend_direction'] = 'decreasing'
                
        except Exception as e:
            warnings.warn(f"Trend decomposition failed: {e}")
        
        return result
    
    def _estimate_convergence_rate(self, losses: np.ndarray) -> Dict:
        """Estimate convergence rate with confidence intervals."""
        result = {'recommendations': []}
        
        if len(losses) < 20:
            return result
        
        try:
            # Fit exponential decay
            L_min = np.min(losses)
            normalized = losses - L_min + 1e-6
            
            start_idx = len(losses) // 4
            end_idx = 3 * len(losses) // 4
            
            if end_idx - start_idx > 10:
                x = np.arange(start_idx, end_idx)
                y = np.log(normalized[start_idx:end_idx])
                
                # Linear regression with confidence interval
                from scipy import stats
                slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
                
                result['convergence_rate'] = -slope
                result['convergence_p_value'] = p_value
                result['convergence_r_squared'] = r_value ** 2
                result['convergence_std_err'] = std_err
                result['estimated_asymptote'] = L_min
                
                # 95% confidence interval for the convergence rate (-slope),
                # so that it brackets result['convergence_rate']
                alpha = 0.05
                t_val = stats.t.ppf(1 - alpha/2, len(x) - 2)
                ci_lower = -slope - t_val * std_err
                ci_upper = -slope + t_val * std_err
                
                result['convergence_ci'] = (ci_lower, ci_upper)
                
                # Estimate time to convergence
                if slope < -0.001:
                    epochs_to_1pct = -np.log(0.01) / (-slope)
                    result['estimated_epochs_to_converge'] = int(epochs_to_1pct)
                
                # Quality assessment
                if r_value ** 2 < 0.5:
                    result['recommendations'].append(
                        f"Poor exponential fit (R²={r_value**2:.2f}) - training may be unstable"
                    )
                
        except Exception as e:
            warnings.warn(f"Convergence estimation failed: {e}")
        
        return result
    
    def _analyze_gradients(self) -> Dict:
        """Analyze gradient statistics."""
        result = {'recommendations': []}
        grad_norms = np.array([g for g in self.history['grad_norm'] if g is not None])
        
        if len(grad_norms) < 10:
            return result
        
        result['mean_grad_norm'] = np.mean(grad_norms)
        result['grad_norm_std'] = np.std(grad_norms)
        result['grad_norm_cv'] = result['grad_norm_std'] / (result['mean_grad_norm'] + 1e-10)
        
        # Gradient trend
        trend = np.polyfit(range(len(grad_norms)), grad_norms, 1)[0]
        result['grad_norm_trend'] = trend
        
        # Check pathologies
        if result['mean_grad_norm'] < 1e-6:
            result['vanishing_gradients'] = True
            result['recommendations'].append(
                "Vanishing gradients - check activation functions and initialization"
            )
        
        if result['mean_grad_norm'] > 100 or np.max(grad_norms) > 1000:
            result['exploding_gradients'] = True
            result['recommendations'].append(
                "Exploding gradients - add gradient clipping (e.g., max_norm=1.0)"
            )
        
        # High variance in gradients
        if result['grad_norm_cv'] > 2.0:
            result['recommendations'].append(
                f"High gradient variance (CV={result['grad_norm_cv']:.2f}) - consider batch normalization"
            )
        
        return result
    
    def plot_advanced_diagnostics(self, figsize=(18, 12)):
        """Create advanced diagnostic plots using statsmodels."""
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)
        
        losses = np.array(self.history['loss'])
        epochs = np.array(self.history['epoch'])
        
        # 1. Loss trajectory with trend
        ax1 = fig.add_subplot(gs[0, :2])
        ax1.plot(epochs, losses, 'b-', alpha=0.3, label='Raw loss')
        ax1.plot(epochs, self.smoothed_loss, 'b-', linewidth=2, label='Smoothed')
        
        # Add trend line
        if len(losses) >= 30:
            try:
                ts = pd.Series(losses)
                decomp = seasonal_decompose(ts, model='additive', period=10, extrapolate_trend='freq')
                ax1.plot(epochs, decomp.trend, 'r--', linewidth=2, label='Trend')
            except Exception:
                pass
        
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss with Trend')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # 2. ACF plot
        ax2 = fig.add_subplot(gs[0, 2:])
        if len(losses) >= 20:
            try:
                plot_acf(losses, lags=min(20, len(losses)//2), ax=ax2)
                ax2.set_title('Autocorrelation Function')
            except Exception:
                ax2.text(0.5, 0.5, 'ACF failed', ha='center', va='center')
        
        # 3. Periodogram
        ax3 = fig.add_subplot(gs[1, :2])
        if len(losses) >= 20:
            try:
                freqs, power = signal.periodogram(signal.detrend(losses))
                ax3.semilogy(freqs[1:], power[1:])
                ax3.set_xlabel('Frequency')
                ax3.set_ylabel('Power')
                ax3.set_title('Periodogram (Oscillation Detection)')
                ax3.grid(True, alpha=0.3)
            except Exception:
                pass
        
        # 4. Q-Q plot
        ax4 = fig.add_subplot(gs[1, 2:])
        stats.probplot(losses, dist="norm", plot=ax4)
        ax4.set_title('Q-Q Plot (Normality Check)')
        ax4.grid(True, alpha=0.3)
        
        # 5. Gradient norms
        ax5 = fig.add_subplot(gs[2, :2])
        grad_norms = [g for g in self.history['grad_norm'] if g is not None]
        if grad_norms:
            ax5.semilogy(epochs[:len(grad_norms)], grad_norms, 'r-', alpha=0.6)
            ax5.set_xlabel('Epoch')
            ax5.set_ylabel('Gradient Norm (log scale)')
            ax5.set_title('Gradient Magnitude Evolution')
            ax5.grid(True, alpha=0.3)
        
        # 6. Learning rate
        ax6 = fig.add_subplot(gs[2, 2:])
        lrs = [lr for lr in self.history['learning_rate'] if lr is not None]
        if lrs:
            ax6.semilogy(epochs[:len(lrs)], lrs, 'm-', linewidth=2)
            ax6.set_xlabel('Epoch')
            ax6.set_ylabel('Learning Rate (log scale)')
            ax6.set_title('Learning Rate Schedule')
            ax6.grid(True, alpha=0.3)
        
        plt.suptitle('Advanced Training Diagnostics', fontsize=16, fontweight='bold')
        return fig
# Complete Training Diagnostics Example
# =====================================

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats, signal
from typing import Dict, List, Tuple, Optional
import warnings
import pandas as pd

# Statistical testing packages
from statsmodels.tsa.stattools import adfuller, kpss
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.seasonal import seasonal_decompose

# Anomaly detection
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler

# Change point detection
try:
    import ruptures as rpt
    HAS_RUPTURES = True
except ImportError:
    HAS_RUPTURES = False
    warnings.warn("Install ruptures: pip install ruptures")

# The examples below assume the TrainingMonitor class defined earlier in this
# section is already in scope.

def simulate_training(scenario: str, n_epochs: int = 100, seed: int = 42):
    """Simulate different training scenarios."""
    monitor = TrainingMonitor()
    np.random.seed(seed)
    
    if scenario == 'healthy':
        lr_schedule = np.logspace(-3, -4, n_epochs)
        for e in range(n_epochs):
            loss = 2.0 * np.exp(-0.03 * e) + 0.1 + np.random.normal(0, 0.02)
            grad_norm = 1.0 * np.exp(-0.02 * e) + np.random.normal(0, 0.05)
            monitor.log(e, loss, max(grad_norm, 0), lr_schedule[e])
    
    elif scenario == 'oscillating':
        lr = 0.1  # Too high
        for e in range(n_epochs):
            loss = 1.0 + 0.5 * np.sin(0.5 * e) * np.exp(-0.01 * e) + np.random.normal(0, 0.1)
            grad_norm = 2.0 + np.random.normal(0, 0.5)
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    elif scenario == 'plateau':
        for e in range(n_epochs):
            if e < 20:
                loss = 2.0 - 0.08 * e + np.random.normal(0, 0.02)
                lr = 0.001
            else:
                loss = 0.4 + np.random.normal(0, 0.01)
                lr = 0.001 * np.exp(-0.05 * (e - 20))
            grad_norm = 0.5 * np.exp(-0.02 * e) + np.random.normal(0, 0.02)
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    elif scenario == 'regime_change':
        # Simulate multiple training regimes (e.g., learning rate changes)
        for e in range(n_epochs):
            if e < 30:
                # Initial high learning rate
                loss = 2.0 * np.exp(-0.05 * e) + 0.5 + np.random.normal(0, 0.05)
                lr = 0.01
            elif e < 60:
                # Reduce learning rate
                loss = 0.5 + 0.3 * np.exp(-0.03 * (e - 30)) + np.random.normal(0, 0.02)
                lr = 0.001
            else:
                # Fine-tuning
                loss = 0.2 + 0.05 * np.exp(-0.02 * (e - 60)) + np.random.normal(0, 0.01)
                lr = 0.0001
            
            grad_norm = 1.0 * np.exp(-0.02 * e) + np.random.normal(0, 0.05)
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    elif scenario == 'diverging':
        for e in range(n_epochs):
            if e < 30:
                loss = 1.0 + np.random.normal(0, 0.1)
                grad_norm = 0.5 + np.random.normal(0, 0.1)
                lr = 0.01
            else:
                loss = 1.0 + 0.1 * (e - 30) ** 1.5
                grad_norm = 0.1 * np.exp(0.1 * (e - 30))
                lr = 0.01
            monitor.log(e, loss, max(grad_norm, 0), lr)
    
    return monitor


# =============================================================================
# EXAMPLE 1: Basic Usage - Single Scenario Analysis
# =============================================================================

print("=" * 80)
print("EXAMPLE 1: Basic Single Scenario Analysis")
print("=" * 80)

# Create a monitor and simulate training
monitor = simulate_training('healthy', n_epochs=100)

# Run diagnostics
diagnosis = monitor.diagnose_loss_curve()

# Print results
print(f"\nStatus: {diagnosis['status']}")
print(f"Message: {diagnosis['message']}")
print(f"Final Loss: {diagnosis['final_loss']:.4f}")
print(f"Min Loss: {diagnosis['min_loss']:.4f}")

# Print statistical tests
if 'adf_pvalue' in diagnosis:
    print(f"\nStationarity Tests:")
    print(f"  ADF p-value: {diagnosis['adf_pvalue']:.4f} (stationary if < 0.05)")
    print(f"  KPSS p-value: {diagnosis['kpss_pvalue']:.4f} (stationary if > 0.05)")
    print(f"  Is Stationary: {not diagnosis.get('is_nonstationary', True)}")

# Print convergence info
if 'convergence_rate' in diagnosis:
    print(f"\nConvergence Analysis:")
    print(f"  Rate (λ): {diagnosis['convergence_rate']:.6f}")
    print(f"  R²: {diagnosis.get('convergence_r_squared', 0):.4f}")
    if 'estimated_epochs_to_converge' in diagnosis:
        print(f"  Estimated epochs to converge: {diagnosis['estimated_epochs_to_converge']}")

# Print recommendations
if diagnosis['recommendations']:
    print(f"\nRecommendations:")
    for i, rec in enumerate(diagnosis['recommendations'], 1):
        print(f"  {i}. {rec}")

# Plot diagnostics
monitor.plot_advanced_diagnostics()
plt.savefig('example1_healthy_diagnostics.png', dpi=150, bbox_inches='tight')
plt.show()


# =============================================================================
# EXAMPLE 2: Compare Multiple Scenarios
# =============================================================================

print("\n" + "=" * 80)
print("EXAMPLE 2: Comparing Multiple Training Scenarios")
print("=" * 80)

scenarios = ['healthy', 'oscillating', 'plateau', 'regime_change', 'diverging']
results_table = []

for scenario in scenarios:
    monitor = simulate_training(scenario, n_epochs=100)
    diagnosis = monitor.diagnose_loss_curve()
    
    results_table.append({
        'Scenario': scenario.replace('_', ' ').title(),
        'Status': diagnosis.get('status', 'N/A'),
        'Final Loss': f"{diagnosis.get('final_loss', 0):.3f}",
        'Is Stationary': 'Yes' if not diagnosis.get('is_nonstationary', False) else 'No',
        'Oscillating': 'Yes' if diagnosis.get('is_oscillating', False) else 'No',
        'Change Points': diagnosis.get('n_changepoints', 0),
        'Convergence Rate': f"{diagnosis.get('convergence_rate', 0):.4f}",
        'Main Issue': diagnosis.get('message', 'None')
    })

# Display comparison table
results_df = pd.DataFrame(results_table)
print("\nTable 1: Training Scenario Comparison")
print(results_df.to_string(index=False))

# Visualize all scenarios
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for idx, scenario in enumerate(scenarios):
    monitor = simulate_training(scenario)
    diagnosis = monitor.diagnose_loss_curve()
    
    ax = axes[idx // 3, idx % 3]
    
    epochs = monitor.history['epoch']
    losses = monitor.history['loss']
    
    # Plot loss
    ax.plot(epochs, losses, 'b-', alpha=0.3, linewidth=1, label='Raw')
    ax.plot(epochs, monitor.smoothed_loss, 'b-', linewidth=2, label='Smoothed')
    
    # Mark change points if detected
    if 'changepoints' in diagnosis and diagnosis['changepoints']:
        for cp in diagnosis['changepoints']:
            if cp < len(epochs):
                ax.axvline(cp, color='r', linestyle='--', alpha=0.5, linewidth=1)
    
    # Styling
    status = diagnosis.get('status', 'unknown')
    color_map = {
        'healthy': 'green',
        'oscillating': 'orange',
        'plateau': 'red',
        'nonstationary': 'purple',
        'unstable': 'darkred',
        'diverging': 'darkred'
    }
    color = color_map.get(status, 'black')
    
    ax.set_xlabel('Epoch', fontsize=9)
    ax.set_ylabel('Loss', fontsize=9)
    title = f"{scenario.replace('_', ' ').title()}"
    ax.set_title(title, fontsize=10, color=color, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=7)
    
    # Add status box
    status_text = f"Status: {status}\n"
    if 'n_changepoints' in diagnosis and diagnosis['n_changepoints'] > 0:
        status_text += f"Changes: {diagnosis['n_changepoints']}"
    
    ax.text(0.95, 0.95, status_text, 
           transform=ax.transAxes, ha='right', va='top',
           bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
           fontsize=7)

# Hide the unused last subplot when the scenarios do not fill the 2x3 grid
if len(scenarios) % 3 != 0:
    axes[-1, -1].axis('off')

plt.suptitle('Figure 1: Training Diagnostic Scenarios Comparison', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('example2_scenario_comparison.png', dpi=150, bbox_inches='tight')
plt.show()


# =============================================================================
# EXAMPLE 3: Detailed Analysis of Oscillating Training
# =============================================================================

print("\n" + "=" * 80)
print("EXAMPLE 3: Deep Dive - Oscillating Training")
print("=" * 80)

monitor = simulate_training('oscillating', n_epochs=150)
diagnosis = monitor.diagnose_loss_curve()

print(f"\nScenario: Oscillating Training")
print(f"Status: {diagnosis['status']}")
print(f"Message: {diagnosis['message']}")

# Spectral analysis results
if 'dominant_frequency' in diagnosis:
    print(f"\nSpectral Analysis:")
    print(f"  Dominant Frequency: {diagnosis['dominant_frequency']:.4f}")
    print(f"  Dominant Period: {diagnosis['dominant_period']:.2f} epochs")
    print(f"  Peak Ratio: {diagnosis.get('peak_ratio', 0):.2f}")
    print(f"  Is Oscillating: {diagnosis['is_oscillating']}")

# Oscillation metrics
if 'oscillation_rate' in diagnosis:
    print(f"\nOscillation Metrics:")
    print(f"  Sign Change Rate: {diagnosis['oscillation_rate']:.2%}")
    print(f"  Autocorrelation (lag-1): {diagnosis.get('loss_autocorr', 0):.3f}")

# Recommendations
print(f"\nRecommendations ({len(diagnosis['recommendations'])} total):")
for i, rec in enumerate(diagnosis['recommendations'][:5], 1):
    print(f"  {i}. {rec}")

# Create detailed plots
monitor.plot_advanced_diagnostics()
plt.savefig('example3_oscillating_details.png', dpi=150, bbox_inches='tight')
plt.show()


# =============================================================================
# EXAMPLE 4: Detecting Regime Changes
# =============================================================================

print("\n" + "=" * 80)
print("EXAMPLE 4: Detecting Training Regime Changes")
print("=" * 80)

monitor = simulate_training('regime_change', n_epochs=100)
diagnosis = monitor.diagnose_loss_curve()

print(f"\nScenario: Multiple Training Regimes")
print(f"Status: {diagnosis['status']}")

if 'changepoints' in diagnosis:
    print(f"\nChange Point Analysis:")
    print(f"  Number of regime changes: {diagnosis['n_changepoints']}")
    print(f"  Change points at epochs: {diagnosis['changepoints']}")
    
    if 'segment_trends' in diagnosis:
        print(f"\n  Segment trends (slope per regime):")
        for i, trend in enumerate(diagnosis['segment_trends']):
            print(f"    Segment {i+1}: {trend:.6f}")

# Visualize with marked change points
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

epochs = np.array(monitor.history['epoch'])
losses = np.array(monitor.history['loss'])

# Top plot: Loss with change points
ax1.plot(epochs, losses, 'b-', alpha=0.4, label='Raw Loss')
ax1.plot(epochs, monitor.smoothed_loss, 'b-', linewidth=2, label='Smoothed Loss')

if 'changepoints' in diagnosis and diagnosis['changepoints']:
    for i, cp in enumerate(diagnosis['changepoints']):
        if cp < len(epochs):
            ax1.axvline(cp, color='red', linestyle='--', linewidth=2, 
                       label=f'Change Point {i+1}' if i == 0 else '')
            ax1.text(cp, ax1.get_ylim()[1], f'CP{i+1}', 
                    ha='center', va='bottom', fontsize=9, color='red')

ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss with Detected Regime Changes')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Bottom plot: Learning rate schedule
lrs = [lr for lr in monitor.history['learning_rate'] if lr is not None]
ax2.semilogy(epochs[:len(lrs)], lrs, 'g-', linewidth=2, label='Learning Rate')

if 'changepoints' in diagnosis and diagnosis['changepoints']:
    for cp in diagnosis['changepoints']:
        if cp < len(epochs):
            ax2.axvline(cp, color='red', linestyle='--', linewidth=2, alpha=0.7)

ax2.set_xlabel('Epoch')
ax2.set_ylabel('Learning Rate (log scale)')
ax2.set_title('Learning Rate Schedule')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('example4_regime_changes.png', dpi=150, bbox_inches='tight')
plt.show()


# =============================================================================
# EXAMPLE 5: Real-time Monitoring During Training
# =============================================================================

print("\n" + "=" * 80)
print("EXAMPLE 5: Real-time Training Monitoring")
print("=" * 80)

print("\nSimulating real-time training with periodic diagnostics...")

monitor = TrainingMonitor()
np.random.seed(42)

# Simulate training with periodic checks
check_frequency = 20  # Check every 20 epochs
n_epochs = 100

for epoch in range(n_epochs):
    # Simulate training step
    loss = 2.0 * np.exp(-0.03 * epoch) + 0.1 + np.random.normal(0, 0.02)
    grad_norm = 1.0 * np.exp(-0.02 * epoch) + np.random.normal(0, 0.05)
    lr = 0.001 * np.exp(-0.01 * epoch)
    
    # Log metrics
    monitor.log(epoch, loss, max(grad_norm, 0), lr)
    
    # Periodic diagnostic check
    if (epoch + 1) % check_frequency == 0:
        diagnosis = monitor.diagnose_loss_curve()
        
        print(f"\n--- Epoch {epoch + 1} Checkpoint ---")
        print(f"Status: {diagnosis['status']}")
        print(f"Current Loss: {loss:.4f}")
        print(f"Gradient Norm: {grad_norm:.4f}")
        
        # Check for issues
        if diagnosis.get('is_oscillating'):
            print("WARNING: Oscillation detected!")
        
        if diagnosis.get('is_nonstationary'):
            print("WARNING: Training not converging!")
        
        if diagnosis.get('vanishing_gradients'):
            print("WARNING: Vanishing gradients!")
        
        # Show top recommendation
        if diagnosis['recommendations']:
            print(f"💡 Recommendation: {diagnosis['recommendations'][0]}")

print("\nFinal diagnostics:")
final_diagnosis = monitor.diagnose_loss_curve()
print(f"Final Status: {final_diagnosis['status']}")
print(f"Final Loss: {final_diagnosis['final_loss']:.4f}")


# =============================================================================
# EXAMPLE 6: Custom Diagnostic Report
# =============================================================================

print("\n" + "=" * 80)
print("EXAMPLE 6: Generate Comprehensive Diagnostic Report")
print("=" * 80)

def generate_diagnostic_report(monitor: TrainingMonitor, save_path: Optional[str] = None):
    """Generate a comprehensive diagnostic report."""
    diagnosis = monitor.diagnose_loss_curve()
    
    report = []
    report.append("=" * 80)
    report.append("TRAINING DIAGNOSTIC REPORT")
    report.append("=" * 80)
    report.append("")
    
    # Summary
    report.append("SUMMARY")
    report.append("-" * 80)
    report.append(f"Status: {diagnosis['status'].upper()}")
    report.append(f"Message: {diagnosis['message']}")
    report.append(f"Total Epochs: {diagnosis['n_epochs']}")
    report.append(f"Final Loss: {diagnosis['final_loss']:.6f}")
    report.append(f"Minimum Loss: {diagnosis['min_loss']:.6f}")
    report.append("")
    
    # Stationarity Analysis
    if 'adf_pvalue' in diagnosis:
        report.append("STATIONARITY ANALYSIS")
        report.append("-" * 80)
        report.append(f"ADF Test Statistic: {diagnosis.get('adf_statistic', 0):.4f}")
        report.append(f"ADF p-value: {diagnosis['adf_pvalue']:.4f}")
        report.append(f"KPSS Test Statistic: {diagnosis.get('kpss_statistic', 0):.4f}")
        report.append(f"KPSS p-value: {diagnosis.get('kpss_pvalue', 0):.4f}")
        report.append(f"Conclusion: {diagnosis.get('stationarity_message', 'N/A')}")
        report.append("")
    
    # Change Point Detection
    if 'changepoints' in diagnosis:
        report.append("REGIME CHANGE DETECTION")
        report.append("-" * 80)
        report.append(f"Number of change points: {diagnosis['n_changepoints']}")
        if diagnosis['changepoints']:
            report.append(f"Change points at epochs: {diagnosis['changepoints']}")
        report.append("")
    
    # Oscillation Analysis
    if 'dominant_frequency' in diagnosis:
        report.append("OSCILLATION ANALYSIS")
        report.append("-" * 80)
        report.append(f"Dominant Period: {diagnosis.get('dominant_period', 0):.2f} epochs")
        report.append(f"Peak Ratio: {diagnosis.get('peak_ratio', 0):.2f}")
        report.append(f"Is Oscillating: {diagnosis.get('is_oscillating', False)}")
        report.append("")
    
    # Convergence Analysis
    if 'convergence_rate' in diagnosis:
        report.append("CONVERGENCE ANALYSIS")
        report.append("-" * 80)
        report.append(f"Convergence Rate (λ): {diagnosis['convergence_rate']:.6f}")
        report.append(f"R-squared: {diagnosis.get('convergence_r_squared', 0):.4f}")
        if 'estimated_epochs_to_converge' in diagnosis:
            report.append(f"Estimated epochs to convergence: {diagnosis['estimated_epochs_to_converge']}")
        report.append("")
    
    # Gradient Analysis
    if 'mean_grad_norm' in diagnosis:
        report.append("GRADIENT ANALYSIS")
        report.append("-" * 80)
        report.append(f"Mean Gradient Norm: {diagnosis['mean_grad_norm']:.6f}")
        report.append(f"Gradient Std Dev: {diagnosis.get('grad_norm_std', 0):.6f}")
        report.append(f"Gradient Trend: {diagnosis.get('grad_norm_trend', 0):.6f}")
        if diagnosis.get('vanishing_gradients'):
            report.append("VANISHING GRADIENTS DETECTED")
        if diagnosis.get('exploding_gradients'):
            report.append("EXPLODING GRADIENTS DETECTED")
        report.append("")
    
    # Recommendations
    if diagnosis['recommendations']:
        report.append("RECOMMENDATIONS")
        report.append("-" * 80)
        for i, rec in enumerate(diagnosis['recommendations'], 1):
            report.append(f"{i}. {rec}")
        report.append("")
    
    report.append("=" * 80)
    
    # Print report
    report_text = "\n".join(report)
    print(report_text)
    
    # Save to file if requested
    if save_path:
        with open(save_path, 'w') as f:
            f.write(report_text)
        print(f"\nReport saved to: {save_path}")
    
    return report_text


# Generate report for a scenario
monitor = simulate_training('regime_change', n_epochs=100)
report = generate_diagnostic_report(monitor, save_path='diagnostic_report.txt')


# =============================================================================
# EXAMPLE 7: Batch Analysis - Compare Training Runs
# =============================================================================

print("\n" + "=" * 80)
print("EXAMPLE 7: Batch Analysis - Multiple Training Runs")
print("=" * 80)

# Simulate multiple training runs of the same scenario with different random seeds
runs = []

for seed in range(5):
    monitor = simulate_training('healthy', n_epochs=100, seed=seed)
    diagnosis = monitor.diagnose_loss_curve()
    
    runs.append({
        'Run': f"Run {seed + 1}",
        'Final Loss': diagnosis['final_loss'],
        'Min Loss': diagnosis['min_loss'],
        'Convergence Rate': diagnosis.get('convergence_rate', 0),
        'R²': diagnosis.get('convergence_r_squared', 0),
        'Status': diagnosis['status']
    })

runs_df = pd.DataFrame(runs)
print("\nTable 2: Multiple Training Runs Comparison")
print(runs_df.to_string(index=False))

print(f"\nStatistics across runs:")
print(f"  Mean Final Loss: {runs_df['Final Loss'].mean():.4f} ± {runs_df['Final Loss'].std():.4f}")
print(f"  Mean Convergence Rate: {runs_df['Convergence Rate'].mean():.6f}")
print(f"  Best Run: {runs_df.loc[runs_df['Final Loss'].idxmin(), 'Run']}")


print("\n" + "=" * 80)
print("All examples completed!")
print("=" * 80)
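The convergence-rate estimate in `_estimate_convergence_rate` fits a straight line to the logarithm of the loss after subtracting an assumed asymptote. The following minimal sketch, which is not part of the monitor class, illustrates one property of that approach on noise-free synthetic data: when the minimum observed loss stands in for the true asymptote, the fitted rate tends to overestimate the true one. The helper name `fit_decay_rate` and the synthetic parameters are assumptions chosen for illustration.

```python
import numpy as np
from scipy import stats

def fit_decay_rate(losses: np.ndarray, asymptote: float, eps: float = 1e-6) -> float:
    """Estimate lambda in loss ~ A * exp(-lambda * epoch) + asymptote.

    Fits a line to log(loss - asymptote) over the middle half of training,
    mirroring the window used by the monitor above.
    """
    start, end = len(losses) // 4, 3 * len(losses) // 4
    x = np.arange(start, end)
    # eps keeps the argument of the logarithm strictly positive
    y = np.log(losses[start:end] - asymptote + eps)
    slope = stats.linregress(x, y).slope
    return -slope

# Synthetic, noise-free curve with a known rate (hypothetical values)
true_rate, true_asymptote = 0.03, 0.1
epochs = np.arange(100)
losses = 2.0 * np.exp(-true_rate * epochs) + true_asymptote

# Fit once with the true asymptote and once with the minimum observed loss
rate_known = fit_decay_rate(losses, true_asymptote, eps=0.0)
rate_min = fit_decay_rate(losses, float(losses.min()))

print(f"true rate:                   {true_rate:.4f}")
print(f"fit with known asymptote:    {rate_known:.4f}")
print(f"fit with min-loss asymptote: {rate_min:.4f}")
```

Because the minimum observed loss sits above the true asymptote while training is still improving, subtracting it steepens the log-transformed curve. It is therefore safer to read the reported rate and the estimated epochs to convergence as rough indicators rather than precise forecasts.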

219.2 Negative Sampling Quality

For contrastive learning methods, the quality of negative samples critically affects representation quality (Mikolov et al. 2013; Gutmann and Hyvärinen 2010). Poorly chosen negatives can lead to collapsed representations or failure to learn meaningful distinctions (Jing et al. 2021).

219.2.1 Theoretical Framework

In contrastive learning, we minimize:

\[ \mathcal{L} = -\log \frac{\exp(\text{sim}(z_i, z_i^+)/\tau)}{\exp(\text{sim}(z_i, z_i^+)/\tau) + \sum_{j=1}^K \exp(\text{sim}(z_i, z_j^-)/\tau)} \]

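where \(z_i^+\) is the positive sample paired with the anchor \(z_i\), \(\{z_j^-\}_{j=1}^K\) are the \(K\) negative samples, \(\text{sim}(\cdot,\cdot)\) is a similarity function (commonly cosine similarity), and \(\tau > 0\) is a temperature parameter. The loss is small only when the positive similarity clearly exceeds the negative similarities, so its effectiveness depends on the separation between positive and negative similarities.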
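
To make the loss concrete, the sketch below evaluates it for a single anchor with NumPy, assuming cosine similarity for sim and a temperature of 0.5. The function name `info_nce_loss` and the toy vectors are illustrative assumptions and are not part of the analysis code in this section. It compares easy negatives, hard negatives that lie close to the anchor, and negatives that are indistinguishable from the positive.

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity with a small epsilon to avoid division by zero."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))

def info_nce_loss(anchor, positive, negatives, tau: float = 0.5) -> float:
    """Contrastive loss for one anchor, one positive, and K negatives."""
    pos_logit = cosine(anchor, positive) / tau
    neg_logits = [cosine(anchor, n) / tau for n in negatives]
    logits = np.array([pos_logit] + neg_logits)
    # Log of the denominator via a numerically stable log-sum-exp
    m = logits.max()
    log_denominator = m + np.log(np.exp(logits - m).sum())
    # -log(exp(pos) / sum(exp(all))) = log_denominator - pos
    return float(log_denominator - pos_logit)

# Toy 2-D vectors (hypothetical values)
anchor = np.array([1.0, 0.0])
positive = np.array([0.9, 0.1])
easy_negatives = [np.array([-1.0, 0.0]), np.array([0.0, 1.0])]
hard_negatives = [np.array([0.8, 0.3]), np.array([0.7, -0.4])]

loss_easy = info_nce_loss(anchor, positive, easy_negatives)
loss_hard = info_nce_loss(anchor, positive, hard_negatives)
# Negatives identical to the positive: the model cannot tell them apart
loss_chance = info_nce_loss(anchor, positive, [positive] * 3)

print(f"easy negatives: {loss_easy:.4f}")
print(f"hard negatives: {loss_hard:.4f}")
print(f"chance level:   {loss_chance:.4f}")
```

Easy negatives contribute little to the denominator and so yield a small loss and a weak training signal, while negatives near the anchor raise it. When the negatives cannot be distinguished from the positive at all, the loss equals \(\log(1 + K)\), which is the value to expect if sampled negatives are in fact unlabeled positives.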

Negative sampling quality analysis with statistical tests
import numpy as np
from typing import Dict, Optional
from scipy import stats
from scipy.spatial.distance import cdist
from sklearn.metrics import roc_auc_score, average_precision_score

def analyze_negative_sampling(embeddings: np.ndarray,
                               positive_edges: np.ndarray,
                               negative_edges: np.ndarray,
                               sample_labels: Optional[np.ndarray] = None) -> Dict:
    """
    Comprehensive analysis of negative sample quality.
    
    
    Parameters
    ----------
    embeddings : np.ndarray, shape (n_nodes, d)
        Node embeddings to evaluate
    positive_edges : np.ndarray, shape (n_pos, 2)
        True positive node pairs
    negative_edges : np.ndarray, shape (n_neg, 2)
        Sampled negative node pairs
    sample_labels : np.ndarray, optional
        True labels for negative samples (if available for validation)
    
    Returns
    -------
    Dict containing:
        - similarity statistics
        - separation metrics (Cohen's d, AUC)
        - false negative estimates
        - sampling quality assessment
        - recommendations
    
    """
    # Compute similarities for positive pairs
    pos_emb1 = embeddings[positive_edges[:, 0]]
    pos_emb2 = embeddings[positive_edges[:, 1]]
    pos_sims = np.sum(pos_emb1 * pos_emb2, axis=1) / (
        np.linalg.norm(pos_emb1, axis=1) * np.linalg.norm(pos_emb2, axis=1) + 1e-10
    )
    
    # Compute similarities for negative pairs
    neg_emb1 = embeddings[negative_edges[:, 0]]
    neg_emb2 = embeddings[negative_edges[:, 1]]
    neg_sims = np.sum(neg_emb1 * neg_emb2, axis=1) / (
        np.linalg.norm(neg_emb1, axis=1) * np.linalg.norm(neg_emb2, axis=1) + 1e-10
    )
    
    # Basic statistics
    pos_mean, pos_std = np.mean(pos_sims), np.std(pos_sims)
    neg_mean, neg_std = np.mean(neg_sims), np.std(neg_sims)
    
    # Effect size (Cohen's d)
    pooled_std = np.sqrt((pos_std**2 + neg_std**2) / 2)
    cohens_d = (pos_mean - neg_mean) / (pooled_std + 1e-10)
    
    # Statistical tests
    # 1. Two-sample t-test
    t_stat, t_pvalue = stats.ttest_ind(pos_sims, neg_sims)
    
    # 2. Mann-Whitney U test (non-parametric)
    u_stat, u_pvalue = stats.mannwhitneyu(pos_sims, neg_sims, alternative='greater')
    
    # 3. Kolmogorov-Smirnov test
    ks_stat, ks_pvalue = stats.ks_2samp(pos_sims, neg_sims)
    
    # False negative rate estimation
    # Method 1: Threshold-based (use 25th percentile of positive similarities)
    pos_threshold = np.percentile(pos_sims, 25)
    false_negatives_threshold = np.mean(neg_sims > pos_threshold)
    
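    # Method 2: Distribution overlap
    # Assume normal distributions and take the point that lies equally many
    # standard deviations below the positive mean and above the negative mean
    overlap_point = (pos_mean * neg_std + neg_mean * pos_std) / (pos_std + neg_std)
    # Mass of the positive distribution falling below the overlap point
    false_neg_distribution = stats.norm.cdf(overlap_point, pos_mean, pos_std)
    # Mass of the negative distribution falling above the overlap point
    false_pos_distribution = stats.norm.sf(overlap_point, neg_mean, neg_std)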
    
    # Discrimination metrics
    # Treat this as binary classification: predict if pair is positive
    labels = np.concatenate([np.ones(len(pos_sims)), np.zeros(len(neg_sims))])
    scores = np.concatenate([pos_sims, neg_sims])
    
    auc_score = roc_auc_score(labels, scores)
    ap_score = average_precision_score(labels, scores)
    
    # Quality assessment based on Cohen's d
    if cohens_d > 2.0:
        quality = 'excellent'
        quality_score = 5
    elif cohens_d > 1.5:
        quality = 'good'
        quality_score = 4
    elif cohens_d > 1.0:
        quality = 'moderate'
        quality_score = 3
    elif cohens_d > 0.5:
        quality = 'poor'
        quality_score = 2
    else:
        quality = 'very_poor'
        quality_score = 1
    
    # Uniformity check (a simple proxy inspired by Wang & Isola, 2020)
    # Sample random embeddings and inspect their pairwise cosine similarities
    n_uniform_samples = min(1000, len(embeddings))
    random_idx = np.random.choice(len(embeddings), n_uniform_samples, replace=False)
    random_emb = embeddings[random_idx]
    
    # Compute pairwise cosine similarities
    pairwise_sims = np.dot(random_emb, random_emb.T)
    pairwise_sims = pairwise_sims / (np.linalg.norm(random_emb, axis=1, keepdims=True) @ 
                                     np.linalg.norm(random_emb, axis=1, keepdims=True).T + 1e-10)
    
    # Uniformity proxy: the spread (std) of pairwise similarities. Embeddings
    # spread evenly over the hypersphere give a small spread; distinct clusters
    # give a large one. A fully collapsed space also has a small spread, so read
    # this together with the mean similarity. Wang & Isola define uniformity
    # differently, as the log of the mean Gaussian potential between pairs.
    triu_idx = np.triu_indices_from(pairwise_sims, k=1)
    uniformity = np.std(pairwise_sims[triu_idx])
    
    # Alignment proxy: mean cosine similarity of positive pairs (higher is better)
    alignment = np.mean(pos_sims)
    
    # Generate recommendations
    recommendations = []
    
    if cohens_d < 1.0:
        recommendations.append("Improve negative sampling strategy - current negatives too similar to positives")
        recommendations.append("Consider hard negative mining")
    
    if false_negatives_threshold > 0.1:
        recommendations.append(f"High false negative rate ({false_negatives_threshold:.1%}) - verify ground truth")
    
    if uniformity > 0.3:
        recommendations.append("High spread in pairwise similarities - embeddings may be clustered rather than uniformly distributed")
    
    if alignment < 0.5:
        recommendations.append("Low positive pair alignment - increase training or check data quality")
    
    if auc_score < 0.8:
        recommendations.append("Poor discriminative power - review embedding model capacity")
    
    # Compile results
    results = {
        # Basic statistics
        'pos_sim_mean': pos_mean,
        'pos_sim_std': pos_std,
        'neg_sim_mean': neg_mean,
        'neg_sim_std': neg_std,
        'separation': pos_mean - neg_mean,
        
        # Effect size
        'cohens_d': cohens_d,
        'quality': quality,
        'quality_score': quality_score,
        
        # Statistical tests
        't_statistic': t_stat,
        't_pvalue': t_pvalue,
        'mannwhitney_u': u_stat,
        'mannwhitney_pvalue': u_pvalue,
        'ks_statistic': ks_stat,
        'ks_pvalue': ks_pvalue,
        
        # False negative estimates
        'false_negatives_threshold': false_negatives_threshold,
        'false_negatives_distribution': false_neg_distribution,
        
        # Discrimination metrics
        'auc_roc': auc_score,
        'average_precision': ap_score,
        
        # Alignment and uniformity (Wang & Isola, 2020)
        'alignment': alignment,
        'uniformity': uniformity,
        
        # Recommendations
        'recommendations': recommendations,
        
        # Raw data for visualization
        'pos_sims': pos_sims,
        'neg_sims': neg_sims
    }
    
    return results


def plot_negative_sampling_analysis(results: Dict, figsize=(15, 10)):
    """
    Create comprehensive visualization of negative sampling quality.
    
    Generates:
    1. Similarity distributions
    2. Cumulative distributions (with KS statistic)
    3. ROC curve
    4. Precision-Recall curve
    5. Quantile-Quantile plot of positive similarities
    6. Statistical summary
    """
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    
    pos_sims = results['pos_sims']
    neg_sims = results['neg_sims']
    
    # 1. Distribution comparison
    ax = axes[0, 0]
    ax.hist(pos_sims, bins=50, alpha=0.6, label='Positive pairs', 
           density=True, color='blue', edgecolor='black')
    ax.hist(neg_sims, bins=50, alpha=0.6, label='Negative pairs', 
           density=True, color='orange', edgecolor='black')
    ax.axvline(results['pos_sim_mean'], color='blue', linestyle='--', 
              linewidth=2, label=f"Pos μ={results['pos_sim_mean']:.3f}")
    ax.axvline(results['neg_sim_mean'], color='orange', linestyle='--', 
              linewidth=2, label=f"Neg μ={results['neg_sim_mean']:.3f}")
    ax.set_xlabel('Cosine Similarity')
    ax.set_ylabel('Density')
    ax.set_title('Similarity Distributions')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)
    
    # 2. Cumulative distributions
    ax = axes[0, 1]
    pos_sorted = np.sort(pos_sims)
    neg_sorted = np.sort(neg_sims)
    ax.plot(pos_sorted, np.linspace(0, 1, len(pos_sorted)), 
           'b-', linewidth=2, label='Positive')
    ax.plot(neg_sorted, np.linspace(0, 1, len(neg_sorted)), 
           'orange', linewidth=2, label='Negative')
    ax.set_xlabel('Cosine Similarity')
    ax.set_ylabel('Cumulative Probability')
    ax.set_title(f'CDF (KS stat={results["ks_statistic"]:.3f})')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 3. ROC Curve
    ax = axes[0, 2]
    labels = np.concatenate([np.ones(len(pos_sims)), np.zeros(len(neg_sims))])
    scores = np.concatenate([pos_sims, neg_sims])
    
    # Compute ROC
    from sklearn.metrics import roc_curve
    fpr, tpr, thresholds = roc_curve(labels, scores)
    ax.plot(fpr, tpr, 'b-', linewidth=2, 
           label=f'AUC={results["auc_roc"]:.3f}')
    ax.plot([0, 1], [0, 1], 'r--', label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 4. Precision-Recall Curve
    ax = axes[1, 0]
    from sklearn.metrics import precision_recall_curve
    precision, recall, _ = precision_recall_curve(labels, scores)
    ax.plot(recall, precision, 'b-', linewidth=2,
           label=f'AP={results["average_precision"]:.3f}')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 5. Q-Q Plot
    ax = axes[1, 1]
    from scipy import stats as sp_stats
    sp_stats.probplot(pos_sims, dist="norm", plot=ax)
    ax.set_title('Q-Q Plot (Positive Similarities)')
    ax.grid(True, alpha=0.3)
    
    # 6. Statistical Summary
    ax = axes[1, 2]
    ax.axis('off')
    
    summary_text = f"""
    NEGATIVE SAMPLING QUALITY
    {'=' * 40}
    
    Effect Size (Cohen's d): {results['cohens_d']:.3f}
    Quality: {results['quality'].upper()}
    
    Separation: {results['separation']:.3f}
    
    Statistical Tests:
      t-test p-value: {results['t_pvalue']:.2e}
      Mann-Whitney p: {results['mannwhitney_pvalue']:.2e}
      KS test p-value: {results['ks_pvalue']:.2e}
    
    Discrimination:
      AUC-ROC: {results['auc_roc']:.3f}
      Avg Precision: {results['average_precision']:.3f}
    
    False Negatives: {results['false_negatives_threshold']:.1%}
    
    Alignment: {results['alignment']:.3f}
    Uniformity: {results['uniformity']:.3f}
    """
    
    ax.text(0.1, 0.9, summary_text, transform=ax.transAxes,
           fontfamily='monospace', fontsize=9, verticalalignment='top')
    
    plt.tight_layout()
    return fig


# Demonstrate negative sampling analysis
print("\n" + "=" * 80)
print("NEGATIVE SAMPLING QUALITY ANALYSIS")
print("=" * 80)

# Run analysis
analysis = analyze_negative_sampling(embeddings, test_edges, negative_edges)

# Print summary
print(f"\nPositive similarity: {analysis['pos_sim_mean']:.3f} ± {analysis['pos_sim_std']:.3f}")
print(f"Negative similarity: {analysis['neg_sim_mean']:.3f} ± {analysis['neg_sim_std']:.3f}")
print(f"Separation: {analysis['separation']:.3f}")
print(f"Cohen's d: {analysis['cohens_d']:.3f} ({analysis['quality']})")
print(f"AUC-ROC: {analysis['auc_roc']:.3f}")
print(f"Average Precision: {analysis['average_precision']:.3f}")
print(f"Potential false negatives: {analysis['false_negatives_threshold']:.1%}")
print(f"Alignment: {analysis['alignment']:.3f}")
print(f"Uniformity: {analysis['uniformity']:.3f}")

if analysis['recommendations']:
    print("\nRecommendations:")
    for i, rec in enumerate(analysis['recommendations'], 1):
        print(f"  {i}. {rec}")

# Create visualization
plot_negative_sampling_analysis(analysis)
plt.suptitle('Figure 3: Negative Sampling Quality Analysis',
             fontsize=14, fontweight='bold', y=1.00)
plt.show()

# Create summary table
print("\n" + "=" * 80)
print("Table 2: Negative Sampling Quality Metrics")
print("=" * 80)

metrics_table = pd.DataFrame([
    {'Metric': 'Positive Similarity (mean)', 'Value': f"{analysis['pos_sim_mean']:.3f}"},
    {'Metric': 'Negative Similarity (mean)', 'Value': f"{analysis['neg_sim_mean']:.3f}"},
    {'Metric': 'Separation', 'Value': f"{analysis['separation']:.3f}"},
    {'Metric': "Cohen's d", 'Value': f"{analysis['cohens_d']:.3f}"},
    {'Metric': 'Quality Rating', 'Value': analysis['quality']},
    {'Metric': 'AUC-ROC', 'Value': f"{analysis['auc_roc']:.3f}"},
    {'Metric': 'Average Precision', 'Value': f"{analysis['average_precision']:.3f}"},
    {'Metric': 'False Negative Rate', 'Value': f"{analysis['false_negatives_threshold']:.1%}"},
    {'Metric': 'Alignment', 'Value': f"{analysis['alignment']:.3f}"},
    {'Metric': 'Uniformity (std)', 'Value': f"{analysis['uniformity']:.3f}"},
])

print(metrics_table.to_string(index=False))
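
The `alignment` and `uniformity` values reported above are simple proxies: the mean cosine similarity of positive pairs and the spread of pairwise similarities. For comparison, the following is a minimal sketch of the two losses as Wang & Isola (2020) define them on L2-normalised embeddings, where alignment is the mean distance between positive pairs raised to a power α and uniformity is the log of the mean Gaussian potential between pairs. The function name `alignment_uniformity` and the defaults `alpha=2`, `t=2` are illustrative choices for this sketch and are not part of the chapter's code. Lower values are better for both quantities.

```python
import numpy as np


def alignment_uniformity(emb_a, emb_b, alpha=2.0, t=2.0):
    """Alignment and uniformity losses in the sense of Wang & Isola (2020).

    emb_a, emb_b: arrays of shape (n, d); row i of emb_a and row i of emb_b
    form a positive pair. Hypothetical helper, for illustration only.
    """
    # Both losses are defined on the unit hypersphere
    a = emb_a / (np.linalg.norm(emb_a, axis=1, keepdims=True) + 1e-10)
    b = emb_b / (np.linalg.norm(emb_b, axis=1, keepdims=True) + 1e-10)

    # Alignment: mean distance between positive pairs, raised to alpha
    alignment = np.mean(np.linalg.norm(a - b, axis=1) ** alpha)

    # Uniformity: log of the mean Gaussian potential over distinct pairs.
    # For unit vectors, ||x - y||^2 = 2 - 2 * cos(x, y).
    sq_dists = np.clip(2.0 - 2.0 * (a @ a.T), 0.0, None)
    upper = np.triu_indices(len(a), k=1)
    uniformity = np.log(np.mean(np.exp(-t * sq_dists[upper])))

    return alignment, uniformity


rng = np.random.default_rng(0)
spread = rng.standard_normal((500, 32))
noisy_copy = spread + 0.1 * rng.standard_normal((500, 32))
collapsed = np.ones((500, 32)) + 0.01 * rng.standard_normal((500, 32))

# Well-spread embeddings: small alignment, clearly negative uniformity
print(alignment_uniformity(spread, noisy_copy))
# Collapsed embeddings: zero alignment, uniformity close to 0
print(alignment_uniformity(collapsed, collapsed))
```

The collapsed case shows why the two numbers are read together: perfect alignment alone says nothing if every embedding sits at nearly the same point.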

220 Production Monitoring

Once deployed, embedding models require continuous monitoring.

220.1 Drift Detection

Comprehensive drift detection
class EmbeddingDriftDetector:
    """Detect drift in embedding systems."""
    
    def __init__(self, reference_embeddings: np.ndarray, reference_labels: np.ndarray = None):
        self.reference = reference_embeddings
        self.reference_labels = reference_labels
        
        self.ref_mean = np.mean(reference_embeddings, axis=0)
        self.ref_std = np.std(reference_embeddings, axis=0)
        self.ref_isotropy = compute_isotropy_metrics(reference_embeddings)
    
    def detect_distribution_drift(self, current_embeddings: np.ndarray) -> Dict:
        curr_mean = np.mean(current_embeddings, axis=0)
        curr_std = np.std(current_embeddings, axis=0)
        
        mean_shift = np.linalg.norm(curr_mean - self.ref_mean)
        variance_ratio = np.mean(curr_std ** 2) / np.mean(self.ref_std ** 2)
        
        # MMD approximation
        mmd = self._compute_mmd(self.reference, current_embeddings)
        
        # KS test per dimension
        ks_stats = []
        for d in range(self.reference.shape[1]):
            stat, _ = stats.ks_2samp(self.reference[:, d], current_embeddings[:, d])
            ks_stats.append(stat)
        
        curr_isotropy = compute_isotropy_metrics(current_embeddings)
        isotropy_drift = abs(curr_isotropy['apcs'] - self.ref_isotropy['apcs'])
        
        return {
            'mean_shift': mean_shift,
            'variance_ratio': variance_ratio,
            'mmd': mmd,
            'max_ks_statistic': np.max(ks_stats),
            'mean_ks_statistic': np.mean(ks_stats),
            'isotropy_drift': isotropy_drift
        }
    
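    def _compute_mmd(self, X: np.ndarray, Y: np.ndarray) -> float:
        """Biased RBF-kernel MMD estimate on at most the first 1000 rows of each set.
        
        Assumes rows are in no meaningful order; shuffle beforehand otherwise.
        """
        # Median heuristic for the kernel bandwidth
        XY = np.vstack([X[:1000], Y[:1000]])
        dists = cdist(XY, XY, 'sqeuclidean')
        gamma = 1.0 / np.median(dists[dists > 0])
        
        n_x, n_y = min(1000, len(X)), min(1000, len(Y))
        X_sample, Y_sample = X[:n_x], Y[:n_y]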
        
        K_xx = np.exp(-gamma * cdist(X_sample, X_sample, 'sqeuclidean'))
        K_yy = np.exp(-gamma * cdist(Y_sample, Y_sample, 'sqeuclidean'))
        K_xy = np.exp(-gamma * cdist(X_sample, Y_sample, 'sqeuclidean'))
        
        mmd_sq = (np.sum(K_xx) / (n_x * n_x) + 
                  np.sum(K_yy) / (n_y * n_y) - 
                  2 * np.sum(K_xy) / (n_x * n_y))
        
        return np.sqrt(max(0, mmd_sq))
    
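    def detect_concept_drift(self, current_embeddings: np.ndarray,
                              current_labels: np.ndarray) -> Dict:
        """Fit a linear probe on the reference set and compare its accuracy on current data."""
        if self.reference_labels is None:
            return {'error': 'No reference labels'}
        
        clf = LogisticRegression(max_iter=1000, random_state=42)
        clf.fit(self.reference, self.reference_labels)
        
        # Note: ref_acc is measured on the probe's own training data, so it is
        # optimistic and 'degradation' somewhat overstates the real drop
        ref_acc = np.mean(clf.predict(self.reference) == self.reference_labels)
        curr_acc = np.mean(clf.predict(current_embeddings) == current_labels)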
        
        return {
            'reference_accuracy': ref_acc,
            'current_accuracy': curr_acc,
            'degradation': ref_acc - curr_acc
        }
    
    def get_drift_summary(self, current_embeddings: np.ndarray,
                          current_labels: np.ndarray = None) -> Dict:
        dist_drift = self.detect_distribution_drift(current_embeddings)
        
        alerts = []
        if dist_drift['mean_shift'] > 0.5:
            alerts.append(f"⚠ High mean shift: {dist_drift['mean_shift']:.3f}")
        if dist_drift['variance_ratio'] < 0.5 or dist_drift['variance_ratio'] > 2.0:
            alerts.append(f"⚠ Variance changed: {dist_drift['variance_ratio']:.3f}")
        if dist_drift['mmd'] > 0.1:
            alerts.append(f"⚠ High MMD: {dist_drift['mmd']:.4f}")
        if dist_drift['isotropy_drift'] > 0.1:
            alerts.append(f"⚠ Isotropy changed: {dist_drift['isotropy_drift']:.3f}")
        
        concept_drift = None
        if current_labels is not None and self.reference_labels is not None:
            concept_drift = self.detect_concept_drift(current_embeddings, current_labels)
            if concept_drift.get('degradation', 0) > 0.1:
                alerts.append(f"⚠ Performance degraded: {concept_drift['degradation']:.1%}")
        
        return {
            'distribution_drift': dist_drift,
            'concept_drift': concept_drift,
            'alerts': alerts,
            'n_alerts': len(alerts)
        }


# Demonstration
np.random.seed(42)
n_ref = 2000
embedding_dim = 64

reference_emb = np.random.randn(n_ref, embedding_dim)
reference_labels = (reference_emb[:, 0] > 0).astype(int)

detector = EmbeddingDriftDetector(reference_emb, reference_labels)

print("DRIFT DETECTION SCENARIOS")
print("=" * 60)

# No drift
print("\n--- Scenario 1: No Drift ---")
current_no_drift = np.random.randn(1000, embedding_dim)
labels_no_drift = (current_no_drift[:, 0] > 0).astype(int)
summary = detector.get_drift_summary(current_no_drift, labels_no_drift)
print(f"MMD: {summary['distribution_drift']['mmd']:.4f}")
print(f"Mean shift: {summary['distribution_drift']['mean_shift']:.4f}")
print(f"Alerts: {summary['alerts']}")

# Mean shift
print("\n--- Scenario 2: Mean Shift ---")
current_shifted = np.random.randn(1000, embedding_dim) + 0.8
labels_shifted = (current_shifted[:, 0] > 0.8).astype(int)
summary = detector.get_drift_summary(current_shifted, labels_shifted)
print(f"MMD: {summary['distribution_drift']['mmd']:.4f}")
print(f"Mean shift: {summary['distribution_drift']['mean_shift']:.4f}")
print(f"Alerts: {summary['alerts']}")

# Concept drift
print("\n--- Scenario 3: Concept Drift ---")
current_concept = np.random.randn(1000, embedding_dim)
labels_concept = (current_concept[:, 5] > 0).astype(int)  # Different rule!
summary = detector.get_drift_summary(current_concept, labels_concept)
print(f"MMD: {summary['distribution_drift']['mmd']:.4f}")
if summary['concept_drift']:
    print(f"Reference accuracy: {summary['concept_drift']['reference_accuracy']:.3f}")
    print(f"Current accuracy: {summary['concept_drift']['current_accuracy']:.3f}")
print(f"Alerts: {summary['alerts']}")
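
The fixed alert thresholds used above (for example `mmd > 0.1`) depend on the embedding dimension and the batch size. One common alternative, sketched here under the assumption that the reference and current samples are small enough to hold a full kernel matrix in memory, is to calibrate the MMD with a permutation test: shuffle the pooled rows, recompute the statistic, and count how often a random split looks at least as different as the real one. The function name `mmd_permutation_test` and its parameters are hypothetical and not part of `EmbeddingDriftDetector`.

```python
import numpy as np
from scipy.spatial.distance import cdist


def mmd_permutation_test(X, Y, n_permutations=200, seed=0):
    """Return (biased MMD^2, permutation p-value) for an RBF kernel.

    Hypothetical helper: the bandwidth uses the same median heuristic as
    the detector above, and the kernel matrix is computed once and reused.
    """
    rng = np.random.default_rng(seed)
    pooled = np.vstack([X, Y])
    n_x, n_total = len(X), len(X) + len(Y)

    sq_dists = cdist(pooled, pooled, 'sqeuclidean')
    gamma = 1.0 / np.median(sq_dists[sq_dists > 0])
    K = np.exp(-gamma * sq_dists)

    def mmd_sq(order):
        # Treat the first n_x indices in `order` as X and the rest as Y
        ix, iy = order[:n_x], order[n_x:]
        return (K[np.ix_(ix, ix)].mean()
                + K[np.ix_(iy, iy)].mean()
                - 2.0 * K[np.ix_(ix, iy)].mean())

    observed = mmd_sq(np.arange(n_total))
    null = np.array([mmd_sq(rng.permutation(n_total))
                     for _ in range(n_permutations)])

    # Add-one correction keeps the p-value strictly positive
    p_value = (1 + np.sum(null >= observed)) / (n_permutations + 1)
    return observed, p_value


rng = np.random.default_rng(1)
reference = rng.standard_normal((300, 16))
same_dist = rng.standard_normal((300, 16))
shifted = rng.standard_normal((300, 16)) + 0.8

for name, current in [('same distribution', same_dist), ('mean shift', shifted)]:
    stat, p = mmd_permutation_test(reference, current)
    print(f"{name}: MMD^2={stat:.4f}, p={p:.3f}")
```

An alert rule such as `p < 0.01` then adapts to sample size and dimension, at the cost of recomputing the statistic once per permutation.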

220.2 Real-Time Monitoring Dashboard

Monitoring dashboard metrics
@dataclass
class MonitoringMetrics:
    """Container for real-time monitoring metrics."""
    timestamp: datetime
    n_embeddings: int
    apcs: float
    participation_ratio: float
    mean_norm: float
    std_norm: float
    drift_from_baseline: float = 0.0
    
    def to_dict(self) -> Dict:
        return {
            'timestamp': self.timestamp,
            'n_embeddings': self.n_embeddings,
            'apcs': self.apcs,
            'participation_ratio': self.participation_ratio,
            'mean_norm': self.mean_norm,
            'std_norm': self.std_norm,
            'drift_from_baseline': self.drift_from_baseline
        }


class RealTimeMonitor:
    """Real-time embedding quality monitoring."""
    
    def __init__(self, baseline_embeddings: np.ndarray = None):
        self.baseline = baseline_embeddings
        self.metrics_history = []
        
        if baseline_embeddings is not None:
            self.baseline_metrics = compute_isotropy_metrics(baseline_embeddings)
    
    def compute_metrics(self, embeddings: np.ndarray) -> MonitoringMetrics:
        isotropy = compute_isotropy_metrics(embeddings)
        
        norms = np.linalg.norm(embeddings, axis=1)
        
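        drift = 0.0
        if self.baseline is not None:
            # Procrustes alignment compares row i of the baseline with row i of the
            # batch, so the distance is only meaningful when rows refer to the same
            # entities and the batch is no larger than the baseline
            alignment = procrustes_similarity(self.baseline[:len(embeddings)], embeddings)
            drift = alignment['procrustes_distance']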
        
        metrics = MonitoringMetrics(
            timestamp=datetime.now(),
            n_embeddings=len(embeddings),
            apcs=isotropy['apcs'],
            participation_ratio=isotropy['participation_ratio'],
            mean_norm=np.mean(norms),
            std_norm=np.std(norms),
            drift_from_baseline=drift
        )
        
        self.metrics_history.append(metrics.to_dict())
        return metrics
    
    def get_history_df(self) -> pd.DataFrame:
        return pd.DataFrame(self.metrics_history)
    
    def check_alerts(self, metrics: MonitoringMetrics) -> List[str]:
        alerts = []
        
        if metrics.apcs > 0.5:
            alerts.append(f"CRITICAL: High anisotropy (APCS={metrics.apcs:.3f})")
        elif metrics.apcs > 0.3:
            alerts.append(f"WARNING: Moderate anisotropy (APCS={metrics.apcs:.3f})")
        
        if metrics.drift_from_baseline > 0.5:
            alerts.append(f"CRITICAL: High drift from baseline ({metrics.drift_from_baseline:.3f})")
        
        if metrics.std_norm / metrics.mean_norm > 0.5:
            alerts.append(f"WARNING: High norm variance (CV={metrics.std_norm/metrics.mean_norm:.3f})")
        
        return alerts


# Simulate monitoring over time
baseline = np.random.randn(1000, 64)
monitor = RealTimeMonitor(baseline_embeddings=baseline)

print("REAL-TIME MONITORING SIMULATION")
print("=" * 60)

for batch in range(10):
    # Simulate batch of new embeddings with gradual drift
    drift_factor = 0.1 * batch
    current_batch = np.random.randn(100, 64) + drift_factor
    
    metrics = monitor.compute_metrics(current_batch)
    alerts = monitor.check_alerts(metrics)
    
    print(f"\nBatch {batch}: APCS={metrics.apcs:.3f}, Drift={metrics.drift_from_baseline:.3f}")
    for alert in alerts:
        print(f"  {alert}")

221 Summary and Best Practices

221.1 Evaluation Checklist

  1. Intrinsic Evaluation
  2. Extrinsic Evaluation
  3. Training Diagnostics
  4. Production Monitoring
  5. Business Metrics

221.2 Key Takeaways

  1. Intrinsic metrics predict downstream problems: Anisotropy and hubness reliably predict poor recommendation quality before deployment.

  2. Temporal splits prevent leakage: Always evaluate link prediction with proper temporal train/test splits in business contexts.

  3. Stability matters for production: Unstable embeddings lead to inconsistent user experiences and difficult debugging.

  4. Monitoring enables proactive intervention: Continuous drift detection catches problems before they impact business metrics.

  5. Multiple metrics provide a more complete picture: No single metric captures embedding quality; use intrinsic, extrinsic, and operational metrics together.
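
Takeaway 2 is easy to get wrong in practice, since a random split lets future interactions leak into training. A minimal sketch of a time-based split follows, assuming interactions live in a pandas DataFrame with a timestamp column. The column names (`user_id`, `item_id`, `timestamp`) and the function `temporal_split` are illustrative and are not taken from the chapter's dataset.

```python
import pandas as pd


def temporal_split(interactions, time_col='timestamp', test_fraction=0.2):
    """Split interactions so every test row is at or after the cutoff time."""
    ordered = interactions.sort_values(time_col, kind='stable')
    cutoff_position = int(len(ordered) * (1 - test_fraction))
    cutoff_time = ordered[time_col].iloc[cutoff_position]

    # Split on the timestamp itself so rows that share the cutoff time
    # all land on the same side
    train = ordered[ordered[time_col] < cutoff_time]
    test = ordered[ordered[time_col] >= cutoff_time]
    return train, test, cutoff_time


interactions = pd.DataFrame({
    'user_id':   [1, 2, 1, 3, 2, 3, 1, 2, 3, 1],
    'item_id':   [10, 11, 12, 10, 13, 14, 15, 10, 12, 16],
    'timestamp': pd.date_range('2024-01-01', periods=10, freq='D'),
})

train, test, cutoff = temporal_split(interactions)
print(len(train), len(test), cutoff.date())
```

Embeddings would then be trained on `train` only, with link prediction scored on `test`, so that no edge used for evaluation predates an edge used for training.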


222 Fairness and Bias in Embeddings

Embeddings can encode and amplify societal biases present in training data. Evaluating fairness is essential for responsible deployment.

222.1 Types of Embedding Bias

Bias Category | Mechanism & Definition | Concrete Example
Representation Bias | Occurs when certain groups have significantly fewer training examples than others. This leads to lower fidelity embeddings for minority groups or long tail items. | Rare product categories (e.g., specialized medical equipment) have poor vector quality compared to popular consumer electronics.
Association Bias | The embedding space captures and amplifies human stereotypes present in the training text, encoding cultural biases as geometric relationships. | The vector for “Doctor” is mathematically closer to “Male” than “Female,” while “Nurse” is closer to “Female.”
Allocation Bias | The system distributes resources or opportunities unequally. The model performs better for the majority group, allocating the best recommendations to them. | New users (cold start) receive generic, high popularity recommendations, while active users receive highly personalized niche content.
Measurement Bias | The metrics used to evaluate the model favor the majority class. A high overall accuracy score can hide catastrophic failure within a specific subgroup. | A model reports 95% AUC overall, but only achieves 60% AUC for a specific minority demographic, which is masked by the global average.
Historical Bias | The training data accurately reflects the world, but the world itself contains historical inequities. The model learns to perpetuate these past patterns. | A hiring algorithm trained on 10 years of resume data penalizes graduates from women’s colleges because the historical hiring data reflects past discrimination.
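
The association bias row can be made concrete with the WEAT effect size (Caliskan et al., 2017), which compares how strongly two target sets associate with two attribute sets. The sketch below uses small synthetic vectors; the function names and the toy data are assumptions for illustration, and real use would pass embeddings of curated word or item lists. The standard deviation uses `ddof=1`, which is one common convention.

```python
import numpy as np


def _cosine_matrix(A, B):
    """Cosine similarity between every row of A and every row of B."""
    A = A / (np.linalg.norm(A, axis=1, keepdims=True) + 1e-10)
    B = B / (np.linalg.norm(B, axis=1, keepdims=True) + 1e-10)
    return A @ B.T


def weat_effect_size(targets_x, targets_y, attrs_a, attrs_b):
    """WEAT effect size for target sets X, Y and attribute sets A, B.

    Each argument is an array of shape (n_items, d).
    """
    def association(W):
        # s(w, A, B) = mean cos(w, a) - mean cos(w, b), one value per row of W
        return (_cosine_matrix(W, attrs_a).mean(axis=1)
                - _cosine_matrix(W, attrs_b).mean(axis=1))

    s_x, s_y = association(targets_x), association(targets_y)
    pooled_std = np.std(np.concatenate([s_x, s_y]), ddof=1)
    return (s_x.mean() - s_y.mean()) / (pooled_std + 1e-10)


# Toy data: one axis plays the role of the attribute direction
rng = np.random.default_rng(0)
d = 16
direction = np.zeros(d)
direction[0] = 1.0

attrs_a = direction + 0.1 * rng.standard_normal((5, d))
attrs_b = -direction + 0.1 * rng.standard_normal((5, d))
biased_x = 2.0 * direction + 0.5 * rng.standard_normal((8, d))
biased_y = -2.0 * direction + 0.5 * rng.standard_normal((8, d))

effect = weat_effect_size(biased_x, biased_y, attrs_a, attrs_b)
print(f"WEAT effect size: {effect:.2f}")
```

Values near zero indicate no measurable association difference, and the sign flips when the two target sets are swapped.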

222.2 Measuring Fairness in Embeddings

Fairness evaluation metrics
class EmbeddingFairnessEvaluator:
    """
    Evaluate fairness properties of embeddings.
    
    Focuses on:
    1. Representation quality parity across groups
    2. Association bias detection
    3. Downstream task fairness
    """
    
    def __init__(self, embeddings: np.ndarray, 
                 group_labels: np.ndarray,
                 group_names: List[str] = None):
        """
        Parameters
        ----------
        embeddings : np.ndarray
            Embedding matrix
        group_labels : np.ndarray
            Group membership for each entity (e.g., demographic group)
        group_names : list, optional
            Human-readable names for groups
        """
        self.embeddings = embeddings
        self.group_labels = group_labels
        self.groups = np.unique(group_labels)
        self.group_names = group_names or [f"Group_{g}" for g in self.groups]
    
    def representation_quality_by_group(self) -> pd.DataFrame:
        """
        Compare embedding quality metrics across groups.
        
        Checks if some groups have lower-quality embeddings.
        """
        results = []
        
        for group in self.groups:
            mask = self.group_labels == group
            group_emb = self.embeddings[mask]
            
            # Isotropy within group
            if len(group_emb) > 100:
                isotropy = compute_isotropy_metrics(group_emb)
                apcs = isotropy['apcs']
                participation = isotropy['participation_ratio']
            else:
                apcs = np.nan
                participation = np.nan
            
            # Norm statistics
            norms = np.linalg.norm(group_emb, axis=1)
            
            # Distance to global centroid
            global_centroid = np.mean(self.embeddings, axis=0)
            group_centroid = np.mean(group_emb, axis=0)
            centroid_distance = np.linalg.norm(group_centroid - global_centroid)
            
            results.append({
                'group': self.group_names[list(self.groups).index(group)],
                'n_entities': len(group_emb),
                'fraction': len(group_emb) / len(self.embeddings),
                'mean_norm': np.mean(norms),
                'std_norm': np.std(norms),
                'apcs': apcs,
                'participation_ratio': participation,
                'centroid_distance': centroid_distance
            })
        
        return pd.DataFrame(results)
    
    def downstream_fairness(self, labels: np.ndarray, 
                            task: str = 'classification') -> pd.DataFrame:
        """
        Evaluate downstream task performance by group.
        
        Parameters
        ----------
        labels : np.ndarray
            Target labels for downstream task
        task : str
            Only 'classification' is implemented; the argument is currently unused
        """
        results = []
        
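        # Out-of-fold predictions from a global model, so per-group scores are
        # not inflated by evaluating on the rows the model was trained on
        from sklearn.model_selection import cross_val_predict
        clf = LogisticRegression(max_iter=1000, random_state=42)
        all_pred = cross_val_predict(clf, self.embeddings, labels, cv=5)
        
        # Evaluate per group
        for group in self.groups:
            mask = self.group_labels == group
            group_emb = self.embeddings[mask]
            group_labels = labels[mask]
            
            if len(group_emb) < 10:
                continue
            
            # Held-out predictions for this group's rows
            pred = all_pred[mask]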
            
            # Metrics
            acc = accuracy_score(group_labels, pred)
            
            # Per-class metrics if multi-class
            unique_labels = np.unique(labels)
            if len(unique_labels) > 2:
                f1 = f1_score(group_labels, pred, average='macro')
            else:
                f1 = f1_score(group_labels, pred)
            
            results.append({
                'group': self.group_names[list(self.groups).index(group)],
                'n_samples': len(group_emb),
                'accuracy': acc,
                'f1_score': f1
            })
        
        df = pd.DataFrame(results)
        
        # Compute fairness metrics
        if len(df) > 1:
            df['accuracy_gap'] = df['accuracy'].max() - df['accuracy']
            df['f1_gap'] = df['f1_score'].max() - df['f1_score']
        
        return df
    
    def compute_association_bias(self, attribute_embeddings: Dict[str, np.ndarray],
                                  target_pairs: List[Tuple[str, str]]) -> pd.DataFrame:
        """
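        Measure association bias by projecting embeddings onto attribute-difference
        directions (a simplified, WEAT-inspired analysis; no WEAT effect size is computed).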
        
        Parameters
        ----------
        attribute_embeddings : dict
            {attribute_name: embedding_vector}
            e.g., {'male': emb_male, 'female': emb_female}
        target_pairs : list
            Pairs of contrasting attributes to test
            e.g., [('male', 'female'), ('young', 'old')]
        """
        results = []
        
        for attr1, attr2 in target_pairs:
            if attr1 not in attribute_embeddings or attr2 not in attribute_embeddings:
                continue
            
            emb1 = attribute_embeddings[attr1]
            emb2 = attribute_embeddings[attr2]
            
            # Compute bias direction
            bias_direction = emb1 - emb2
            bias_direction = bias_direction / (np.linalg.norm(bias_direction) + 1e-10)
            
            # Project all embeddings onto bias direction
            projections = self.embeddings @ bias_direction
            
            # Compute per-group statistics
            for group in self.groups:
                mask = self.group_labels == group
                group_proj = projections[mask]
                
                results.append({
                    'attribute_pair': f"{attr1} vs {attr2}",
                    'group': self.group_names[list(self.groups).index(group)],
                    'mean_projection': np.mean(group_proj),
                    'std_projection': np.std(group_proj)
                })
        
        return pd.DataFrame(results)


# Demonstration with simulated data
np.random.seed(42)
n = 2000
embedding_dim = 64

# Simulate groups with different embedding quality
# Group 0: majority, well-represented
# Group 1: minority, less data, noisier embeddings
group_labels = np.random.choice([0, 1], size=n, p=[0.8, 0.2])
task_labels = np.random.randint(0, 3, n)  # 3-class task

embeddings_fair = np.zeros((n, embedding_dim))
for i in range(n):
    if group_labels[i] == 0:
        # Majority group: clean embeddings
        embeddings_fair[i] = np.random.randn(embedding_dim) * 0.5
    else:
        # Minority group: noisier embeddings (simulating less training data)
        embeddings_fair[i] = np.random.randn(embedding_dim) * 1.0

# Add task-relevant signal
class_centers = np.random.randn(3, embedding_dim)
for i in range(n):
    embeddings_fair[i] += class_centers[task_labels[i]] * 0.3

print("FAIRNESS EVALUATION")
print("=" * 60)

evaluator = EmbeddingFairnessEvaluator(
    embeddings_fair, 
    group_labels,
    group_names=['Majority', 'Minority']
)

print("\n--- Representation Quality by Group ---")
rep_quality = evaluator.representation_quality_by_group()
print(rep_quality.to_string(index=False))

print("\n--- Downstream Task Fairness ---")
downstream = evaluator.downstream_fairness(task_labels)
print(downstream.to_string(index=False))

# Check for significant gaps
max_acc_gap = downstream['accuracy_gap'].max()
if max_acc_gap > 0.1:
    print(f"\n⚠ WARNING: Accuracy gap of {max_acc_gap:.1%} detected between groups")
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Representation quality
ax = axes[0]
x = np.arange(len(rep_quality))
width = 0.35
ax.bar(x - width/2, rep_quality['mean_norm'], width, label='Mean Norm', alpha=0.8)
ax.bar(x + width/2, rep_quality['std_norm'], width, label='Std Norm', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels(rep_quality['group'])
ax.set_ylabel('Norm')
ax.set_title('Embedding Norm by Group')
ax.legend()

# Downstream accuracy
ax = axes[1]
colors = ['steelblue' if gap < 0.05 else 'coral' for gap in downstream['accuracy_gap']]
ax.bar(downstream['group'], downstream['accuracy'], color=colors, alpha=0.8)
ax.set_ylabel('Accuracy')
ax.set_title('Downstream Task Accuracy\n(Coral = gap of 5+ points from best)')
ax.set_ylim([0, 1])

# Group sizes
ax = axes[2]
ax.pie(rep_quality['fraction'], labels=rep_quality['group'], autopct='%1.1f%%',
       colors=['steelblue', 'coral'])
ax.set_title('Group Distribution')

plt.tight_layout()
plt.show()
Figure 222.1

223 Scalability and Computational Considerations

As embedding systems scale to millions of entities, computational efficiency becomes critical.
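
When the embedding matrix does not fit in memory, simple statistics can still be tracked in a single pass over batches. The following sketch merges per-batch means and variances with the standard pairwise (parallel) variance update, so per-dimension statistics can be monitored without storing past batches. The class name `StreamingMoments` is hypothetical and not part of the chapter's code.

```python
import numpy as np


class StreamingMoments:
    """Running per-dimension mean and variance over batches of embeddings."""

    def __init__(self, dim):
        self.count = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)  # running sum of squared deviations from the mean

    def update(self, batch):
        batch = np.asarray(batch, dtype=float)
        n_b = len(batch)
        if n_b == 0:
            return
        mean_b = batch.mean(axis=0)
        m2_b = ((batch - mean_b) ** 2).sum(axis=0)

        # Merge the batch summary into the running summary
        delta = mean_b - self.mean
        total = self.count + n_b
        self.mean = self.mean + delta * n_b / total
        self.m2 = self.m2 + m2_b + delta ** 2 * self.count * n_b / total
        self.count = total

    @property
    def variance(self):
        # Population variance, matching np.var's default (ddof=0)
        return self.m2 / max(self.count, 1)


rng = np.random.default_rng(0)
embeddings = rng.standard_normal((10_000, 8)) * 2.0 + 1.0

moments = StreamingMoments(dim=8)
for start in range(0, len(embeddings), 1_000):
    moments.update(embeddings[start:start + 1_000])

# Compare the streamed statistics with the full-matrix computation
print(np.allclose(moments.mean, embeddings.mean(axis=0)))
print(np.allclose(moments.variance, embeddings.var(axis=0)))
```

The same running mean and variance can feed the mean-shift and variance-ratio checks used for drift detection, with memory that depends only on the embedding dimension.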

223.1 Approximate Evaluation Methods

Scalable evaluation techniques
class ScalableEmbeddingEvaluator:
    """
    Evaluation methods that scale to large embedding matrices.
    
    Key techniques:
    1. Sampling-based metrics
    2. Locality-sensitive hashing for approximate NN
    3. Incremental/streaming evaluation
    """
    
    def __init__(self, embeddings: np.ndarray, sample_size: int = 10000):
        """
        Parameters
        ----------
        embeddings : np.ndarray
            Full embedding matrix
        sample_size : int
            Number of samples for approximate metrics
        """
        self.embeddings = embeddings
        self.n, self.d = embeddings.shape
        self.sample_size = min(sample_size, self.n)
        
        # Pre-compute sample indices
        np.random.seed(42)
        self.sample_idx = np.random.choice(self.n, self.sample_size, replace=False)
        self.sample = embeddings[self.sample_idx]
    
    def approximate_isotropy(self, n_pairs: int = 50000) -> Dict:
        """
        Estimate isotropy metrics using sampling.
        
        Instead of computing all O(n²) pairwise similarities,
        draw about sqrt(2 * n_pairs) points and use every pair among
        them, which yields roughly n_pairs similarities.
        """
        # Sample points; m points give m * (m - 1) / 2 distinct pairs
        n_sample = min(self.sample_size, int(np.sqrt(2 * n_pairs)) + 1)
        idx = np.random.choice(self.n, n_sample, replace=False)
        sample = self.embeddings[idx]
        
        # Normalize
        norms = np.linalg.norm(sample, axis=1, keepdims=True)
        normalized = sample / (norms + 1e-10)
        
        # Compute similarities for sampled pairs
        sim_matrix = normalized @ normalized.T
        upper_tri = np.triu_indices(n_sample, k=1)
        sims = sim_matrix[upper_tri]
        
        # Approximate eigenvalues using randomized SVD
        from sklearn.utils.extmath import randomized_svd
        
        centered = sample - np.mean(sample, axis=0)
        n_components = min(50, self.d, n_sample - 1)
        
        _, s, _ = randomized_svd(centered, n_components=n_components, random_state=42)
        eigenvalues = (s ** 2) / n_sample
        
        # Participation ratio from top eigenvalues
        participation = (np.sum(eigenvalues) ** 2) / np.sum(eigenvalues ** 2)
        
        return {
            'apcs': np.mean(sims),
            'apcs_std': np.std(sims),
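            # Central 95% range of the sampled similarities
            # (a spread of the distribution, not a confidence interval on the mean)
            'apcs_confidence_interval': (
                np.percentile(sims, 2.5),
                np.percentile(sims, 97.5)
            ),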
            'approximate_participation_ratio': participation,
            'n_samples': n_sample,
            'n_pairs': len(sims)
        }
    
    def approximate_hubness(self, k: int = 10, n_queries: int = 5000) -> Dict:
        """
        Estimate hubness using sampling.
        
        Instead of computing k-NN for all points, sample queries.
        """
        # Sample query points
        query_idx = np.random.choice(self.n, min(n_queries, self.n), replace=False)
        queries = self.embeddings[query_idx]
        
        # Sample database points
        db_idx = np.random.choice(self.n, self.sample_size, replace=False)
        database = self.embeddings[db_idx]
        
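        # Compute distances
        distances = cdist(queries, database, metric='euclidean')
        
        # A query that was also drawn into the database would be its own
        # nearest neighbor; mask those self-matches out
        distances[query_idx[:, None] == db_idx[None, :]] = np.inf
        
        # Find k-NN for each query (partial sort avoids a full argsort)
        knn_indices = np.argpartition(distances, k - 1, axis=1)[:, :k]
        
        # Count how often each database point appears in a k-NN list
        k_occurrences = np.bincount(knn_indices.ravel(), minlength=self.sample_size)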
        
        return {
            'skewness': stats.skew(k_occurrences),
            'mean_occurrences': np.mean(k_occurrences),
            'std_occurrences': np.std(k_occurrences),
            'max_occurrences': np.max(k_occurrences),
            'n_queries': len(queries),
            'n_database': len(database)
        }
    
    def streaming_drift_detection(self, new_batch: np.ndarray,
                                   window_size: int = 1000) -> Dict:
        """
        Detect drift in streaming setting.
        
        Compare new batch to recent historical window.
        """
        # Use the most recent embeddings as reference
        # (assumes rows are stored in chronological order)
        if self.n < window_size:
            reference = self.embeddings
        else:
            reference = self.embeddings[-window_size:]
        
        # Quick drift metrics
        ref_mean = np.mean(reference, axis=0)
        new_mean = np.mean(new_batch, axis=0)
        
        mean_shift = np.linalg.norm(new_mean - ref_mean)
        
        # Variance comparison
        ref_var = np.mean(np.var(reference, axis=0))
        new_var = np.mean(np.var(new_batch, axis=0))
        var_ratio = new_var / (ref_var + 1e-10)
        
        # Quick KS test on first few dimensions
        ks_stats = []
        for d in range(min(10, self.d)):
            stat, _ = stats.ks_2samp(reference[:, d], new_batch[:, d])
            ks_stats.append(stat)
        
        return {
            'mean_shift': mean_shift,
            'variance_ratio': var_ratio,
            'mean_ks_stat': np.mean(ks_stats),
            'max_ks_stat': np.max(ks_stats),
            # Heuristic thresholds; mean_shift is an L2 norm, so it grows with dimension
            'drift_detected': bool(mean_shift > 0.5 or np.max(ks_stats) > 0.3)
        }


# Demonstration with large-scale data
print("SCALABLE EVALUATION DEMONSTRATION")
print("=" * 60)

# Simulate large embedding matrix
n_large = 100000
d = 128
print(f"\nSimulating {n_large:,} embeddings of dimension {d}...")

np.random.seed(42)
large_embeddings = np.random.randn(n_large, d)

# Time comparison
import time

# Approximate evaluation on the full matrix
print("\n--- Timing Comparison ---")

start = time.time()
scalable_eval = ScalableEmbeddingEvaluator(large_embeddings, sample_size=5000)
approx_isotropy = scalable_eval.approximate_isotropy(n_pairs=10000)
approx_time = time.time() - start
# Report the number of points actually sampled by approximate_isotropy
print(f"Approximate isotropy ({approx_isotropy['n_samples']} samples): {approx_time:.2f}s")
print(f"  APCS: {approx_isotropy['apcs']:.4f} ± {approx_isotropy['apcs_std']:.4f}")

# Full evaluation on small subset for comparison
start = time.time()
small_subset = large_embeddings[:5000]
full_isotropy = compute_isotropy_metrics(small_subset)
full_time = time.time() - start
print(f"Full isotropy (5000 samples): {full_time:.2f}s")
print(f"  APCS: {full_isotropy['apcs']:.4f} ± {full_isotropy['apcs_std']:.4f}")

# Streaming drift detection
print("\n--- Streaming Drift Detection ---")
new_batch = np.random.randn(1000, d) + 0.3  # Slightly shifted
drift = scalable_eval.streaming_drift_detection(new_batch)
print(f"Mean shift: {drift['mean_shift']:.4f}")
print(f"Drift detected: {drift['drift_detected']}")
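
The class docstring above lists locality-sensitive hashing as a key technique, but the methods shown rely on exact distance computation over samples. As a rough illustration of what the LSH step could look like, here is a minimal sketch of sign-random-projection hashing for cosine similarity. The class name `RandomHyperplaneLSH` and the parameters `n_bits` and `n_tables` are hypothetical and not part of the chapter's code; a production system would more likely use a dedicated approximate nearest neighbor library.

```python
import numpy as np
from collections import defaultdict


class RandomHyperplaneLSH:
    """Illustrative sign-random-projection LSH index for cosine similarity."""

    def __init__(self, embeddings, n_bits=12, n_tables=8, seed=0):
        rng = np.random.default_rng(seed)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        self.normalized = embeddings / (norms + 1e-10)
        d = embeddings.shape[1]
        # One independent set of random hyperplanes per hash table
        self.planes = rng.standard_normal((n_tables, d, n_bits))
        self.powers = 1 << np.arange(n_bits)
        self.tables = []
        for t in range(n_tables):
            table = defaultdict(list)
            for row, key in enumerate(self._hash(embeddings, t)):
                table[int(key)].append(row)
            self.tables.append(table)

    def _hash(self, X, t):
        # Sign pattern of the projections, packed into one integer per row
        bits = (X @ self.planes[t]) > 0
        return bits.astype(np.int64) @ self.powers

    def query(self, x, k=10):
        # Candidates are all points sharing a bucket with x in any table
        candidates = set()
        for t, table in enumerate(self.tables):
            key = int(self._hash(x[None, :], t)[0])
            candidates.update(table.get(key, []))
        if not candidates:
            return np.array([], dtype=int)
        cand = np.fromiter(candidates, dtype=int)
        # Exact cosine re-ranking, but only over the candidate set
        xn = x / (np.linalg.norm(x) + 1e-10)
        sims = self.normalized[cand] @ xn
        return cand[np.argsort(-sims)[:k]]


# Small synthetic demonstration (data is random, purely for illustration)
demo_rng = np.random.default_rng(1)
demo_embeddings = demo_rng.standard_normal((5000, 64))
lsh_index = RandomHyperplaneLSH(demo_embeddings)
neighbors = lsh_index.query(demo_embeddings[123], k=5)
print("Approximate neighbors of item 123:", neighbors)
```

More bits per table shrink the buckets (faster queries, lower recall), while more tables raise recall at the cost of memory, so both settings would need tuning against an exact baseline on a sample.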

223.2 Memory-Efficient Evaluation

Memory-efficient evaluation for very large embeddings
def chunked_pairwise_similarity(embeddings: np.ndarray, 
                                 chunk_size: int = 1000,
                                 n_samples: int = 100000) -> Dict:
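    """
    Compute pairwise similarity statistics without materializing the
    full n × n similarity matrix.
    
    Similarities are computed one chunk pair at a time, so peak memory
    is bounded by chunk_size² rather than n². The embeddings themselves
    (and one normalized copy) must still fit in memory.
    """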
    n = len(embeddings)
    
    # Normalize once
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / (norms + 1e-10)
    
    # Sample pairs: spread the budget evenly across chunk pairs so that
    # every region of the similarity matrix contributes. Filling the budget
    # block by block would draw every sample from the first chunk alone.
    n_pairs = min(n_samples, n * (n - 1) // 2)
    n_chunks = (n + chunk_size - 1) // chunk_size
    n_blocks = n_chunks * (n_chunks + 1) // 2
    per_block = max(1, -(-n_pairs // n_blocks))  # ceiling division
    
    similarities = []
    
    for i in range(0, n, chunk_size):
        chunk_i = normalized[i:i+chunk_size]
        
        for j in range(i, n, chunk_size):
            chunk_j = normalized[j:j+chunk_size]
            
            # Compute similarities for this chunk pair
            sim_block = chunk_i @ chunk_j.T
            
            if i == j:
                # Same chunk: take upper triangle (excludes self-similarity)
                upper = np.triu_indices(len(chunk_i), k=1)
                block_sims = sim_block[upper]
            else:
                # Different chunks: take all
                block_sims = sim_block.ravel()
            
            # Sample this block's share of the budget
            if per_block < len(block_sims):
                idx = np.random.choice(len(block_sims), per_block, replace=False)
                block_sims = block_sims[idx]
            
            similarities.extend(block_sims)
    
    # Total may exceed n_pairs slightly because of per-block rounding
    similarities = np.array(similarities)
    
    return {
        'mean': np.mean(similarities),
        'std': np.std(similarities),
        'median': np.median(similarities),
        'percentile_5': np.percentile(similarities, 5),
        'percentile_95': np.percentile(similarities, 95),
        'n_pairs': len(similarities)
    }


# Test memory-efficient computation
print("\n--- Memory-Efficient Pairwise Similarity ---")
mem_efficient_result = chunked_pairwise_similarity(
    large_embeddings[:10000], 
    chunk_size=1000, 
    n_samples=50000
)
print(f"Mean similarity: {mem_efficient_result['mean']:.4f}")
print(f"Std similarity: {mem_efficient_result['std']:.4f}")
print(f"Pairs computed: {mem_efficient_result['n_pairs']:,}")

  1. In 2D or 3D (our intuition), random vectors can point anywhere (e.g., some will be similar, some orthogonal, some opposite). In high dimensions (d = 100, 1000, etc.), random vectors are almost always nearly orthogonal to each other. This is not by design; it is simply a consequence of the geometry of high-dimensional space.

    1. Why Does This Happen?

    Cosine similarity between two random unit vectors \(\mathbf{u}, \mathbf{v} \in \mathbb{R}^d\):

    \(\cos(\mathbf{u}, \mathbf{v}) = \sum_{i=1}^d u_i v_i\)

    Each term \(u_i v_i\) is a product of two independent, zero-mean random values. Some terms are positive, some negative, so they tend to cancel out.

    By the central limit theorem, this sum of \(d\) random terms concentrates around 0, with standard deviation proportional to \(1/\sqrt{d}\). So as \(d\) grows:

    • Mean cosine similarity \(\to\) 0

    • Standard deviation \(\to\) 0

    All pairwise similarities concentrate tightly around zero (orthogonal).
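
The concentration argument above can be checked numerically. The following is a small simulation sketch, assuming random directions are drawn by normalizing standard Gaussian vectors; the function name `random_cosine_stats` and the chosen dimensions are illustrative only.

```python
import numpy as np


def random_cosine_stats(d, n_pairs=5000, seed=0):
    """Mean and std of cosine similarity between random unit-vector pairs in R^d."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((n_pairs, d))
    v = rng.standard_normal((n_pairs, d))
    # Normalizing Gaussian vectors gives uniformly random directions
    u /= np.linalg.norm(u, axis=1, keepdims=True)
    v /= np.linalg.norm(v, axis=1, keepdims=True)
    cos = np.sum(u * v, axis=1)
    return cos.mean(), cos.std()


results = {d: random_cosine_stats(d) for d in (2, 10, 100, 1000)}
for d, (mean, std) in results.items():
    print(f"d={d:5d}  mean={mean:+.4f}  std={std:.4f}  1/sqrt(d)={1 / np.sqrt(d):.4f}")
```

The empirical standard deviation should track \(1/\sqrt{d}\) closely at every dimension, while the mean stays near zero throughout.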

    1. Why “Blessing” for Some, “Curse” for Others?

    Blessing: When you WANT things to be separable

    • Hashing/fingerprinting: Random projections stay nearly orthogonal, so collisions are rare

    • Random indexing: You can pack many nearly-orthogonal vectors into the same space

    • Compressed sensing: Random measurements preserve distances well

    If your goal is “keep things distinct,” high dimensions help for free.

    Curse: When you NEED fine-grained similarity

    • Nearest neighbor search: If all distances are nearly equal, “nearest” becomes meaningless

    • Recommendation: If all user-item similarities cluster around 0, you can’t rank preferences

    • Clustering: Distances between clusters vs. within clusters both concentrate, so separation vanishes

    If your goal is “find things that are genuinely close,” the concentration of distances makes this harder.

    1. The Practical Implication for Embeddings

    When embeddings are isotropic (spread evenly like random vectors), you inherit this high-dimensional geometry:

    • Random pairs have similarity \(\approx 0\)

    • Semantically similar pairs stand out (similarity \(\approx\) 0.5 or 0.8)

    • The signal (true similarity) rises above the noise (baseline near-orthogonality)

    When embeddings are anisotropic (clustered along dominant directions):

    • Random pairs have similarity \(\approx\) 0.7 or 0.9 (all pointing similar directions)

    • Semantically similar pairs also have similarity \(\approx\) 0.8 or 0.95

    • Signal and noise overlap, so you can’t distinguish real similarity from the baseline

    That’s why anisotropy kills discriminative power: it raises the “floor” of similarities so high that meaningful differences get lost.↩︎
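
To make the raised "floor" concrete, the sketch below compares the similarity of unrelated pairs (noise floor) with that of noisy copies (signal) in an isotropic space and in one where every vector shares a large common component. The data is synthetic, and the offset magnitude of 30 along one axis is an arbitrary assumption chosen only to make the effect visible.

```python
import numpy as np


def cosine_rows(A, B):
    """Row-wise cosine similarity between matching rows of A and B."""
    num = np.sum(A * B, axis=1)
    return num / (np.linalg.norm(A, axis=1) * np.linalg.norm(B, axis=1))


rng = np.random.default_rng(0)
n, d = 2000, 128
base = rng.standard_normal((n, d))
similar = base + 0.5 * rng.standard_normal((n, d))  # noisy copies: "true" neighbors
unrelated = rng.standard_normal((n, d))             # random, unrelated items

offset = np.zeros(d)
offset[0] = 30.0  # assumed dominant direction shared by every embedding


def floor_and_signal(shift):
    floor = cosine_rows(base + shift, unrelated + shift).mean()
    signal = cosine_rows(base + shift, similar + shift).mean()
    return floor, signal


iso_floor, iso_signal = floor_and_signal(0.0)
aniso_floor, aniso_signal = floor_and_signal(offset)

print(f"Isotropic:   floor={iso_floor:.3f}  signal={iso_signal:.3f}")
print(f"Anisotropic: floor={aniso_floor:.3f}  signal={aniso_signal:.3f}")
```

In the shifted space both numbers are pushed toward 1, so the gap between related and unrelated pairs shrinks even though the underlying relationships are unchanged.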

  2. AUC-ROC evaluates the ranking without committing to any specific threshold. It asks: “If I pick a random positive and random negative, does the positive score higher?” This is computed across all possible thresholds simultaneously. Hence, you’re measuring the quality of the ordering, not the quality of any particular binary decision rule.↩︎
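
The pairwise reading of AUC-ROC can be verified directly: count the fraction of (positive, negative) pairs in which the positive scores higher, and compare with scikit-learn's `roc_auc_score`. The helper `pairwise_auc` is a hypothetical name, and the labels and scores are taken from the ranked list used in the MRR and Hits@k footnotes.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def pairwise_auc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count half)."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive against every negative
    return float(np.mean(diff > 0) + 0.5 * np.mean(diff == 0))


y = [1, 0, 0, 1, 0, 1, 0, 0, 1, 0]
s = [0.95, 0.91, 0.88, 0.85, 0.82, 0.78, 0.72, 0.65, 0.58, 0.52]

auc_pairs = pairwise_auc(y, s)
auc_sklearn = roc_auc_score(y, s)
print(f"Pairwise AUC: {auc_pairs:.4f}   sklearn AUC: {auc_sklearn:.4f}")
```

No threshold appears anywhere in the pairwise computation, which is the point of the footnote: only the ordering of scores matters.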

  3. What “\(k\)-th threshold” means for AP:

    The precision-recall curve is constructed by sweeping through thresholds. Imagine sorting all scores descending and moving down the list. At each position \(k\):

    • The implicit threshold is “the score of the \(k\)-th item”
    • \(P_k\) = precision if you predict the top \(k\) as edges
    • \(R_k\) = recall if you predict the top \(k\) as edges

    So “the \(k\)-th threshold” really means “the threshold that would make exactly \(k\) positive predictions” (i.e., it’s iterating through the ranked list position by position, not through explicit threshold values).

    Concrete example:

    | Rank | Score | True Label | \(P_k\) | \(R_k\) |
    |---|---|---|---|---|
    | 1 | 0.95 | 1 (edge) | 1/1 = 1.0 | 1/4 = 0.25 |
    | 2 | 0.87 | 1 (edge) | 2/2 = 1.0 | 2/4 = 0.50 |
    | 3 | 0.82 | 0 (non-edge) | 2/3 = 0.67 | 2/4 = 0.50 |
    | 4 | 0.75 | 1 (edge) | 3/4 = 0.75 | 3/4 = 0.75 |

    AP integrates precision across all these positions, weighting by the change in recall at each step.
    ↩︎
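
The position-by-position sweep can be written in a few lines. This is a minimal sketch, assuming the labels are already sorted by descending score; `average_precision` and its `n_positives` argument are hypothetical names. Because the table above stops at rank 4 while four edges exist in total, the first call yields only the partial sum through rank 4. The second call uses a made-up completion of the list (fourth edge at rank 5) purely for illustration.

```python
import numpy as np


def average_precision(labels_ranked, n_positives=None):
    """Sum of P_k * (R_k - R_{k-1}) over ranked positions k."""
    labels = np.asarray(labels_ranked, dtype=float)
    if n_positives is None:
        n_positives = labels.sum()
    hits = np.cumsum(labels)                 # true edges found in the top k
    ranks = np.arange(1, len(labels) + 1)
    precision = hits / ranks                 # P_k
    recall = hits / n_positives              # R_k
    delta_recall = np.diff(np.concatenate([[0.0], recall]))
    return float(np.sum(precision * delta_recall))


# The four rows from the table, with 4 true edges in total
partial_ap = average_precision([1, 1, 0, 1], n_positives=4)

# Hypothetical completion: the remaining edge is found at rank 5
full_ap = average_precision([1, 1, 0, 1, 1])

print(f"Partial sum through rank 4: {partial_ap:.4f}")
print(f"AP for the hypothetical full list: {full_ap:.4f}")
```

Rank 3 contributes nothing because recall does not change there, which is why only the positions of true edges affect the result.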
  4. Setup

    Say we have 4 test edges (true connections we’re trying to predict) and 6 negative samples (non-edges). We score all 10 pairs and sort by score descending:

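    | Rank | Node Pair | Score | True Edge? |
    |---|---|---|---|
    | 1 | (A, B) | 0.95 | Yes |
    | 2 | (C, D) | 0.91 | No |
    | 3 | (E, F) | 0.88 | No |
    | 4 | (G, H) | 0.85 | Yes |
    | 5 | (I, J) | 0.82 | No |
    | 6 | (K, L) | 0.78 | Yes |
    | 7 | (M, N) | 0.72 | No |
    | 8 | (O, P) | 0.65 | No |
    | 9 | (Q, R) | 0.58 | Yes |
    | 10 | (S, T) | 0.52 | No |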

    Computing MRR

    We only care about where the true edges landed:

    • True edge (A, B) → Rank 1 → Reciprocal rank = \(\frac{1}{1} = 1.0\)
    • True edge (G, H) → Rank 4 → Reciprocal rank = \(\frac{1}{4} = 0.25\)
    • True edge (K, L) → Rank 6 → Reciprocal rank = \(\frac{1}{6} = 0.167\)
    • True edge (Q, R) → Rank 9 → Reciprocal rank = \(\frac{1}{9} = 0.111\)

    \[\text{MRR} = \frac{1}{4}\left(1.0 + 0.25 + 0.167 + 0.111\right) = \frac{1.528}{4} = 0.382\]

    Interpretation

    MRR = 0.382 means that the harmonic mean of the true-edge ranks is \(1/0.382 \approx 2.6\), even though their arithmetic mean rank is 5. The reciprocal rank heavily rewards placing true edges at the very top: the edge at rank 1 contributes 1.0, while the edge at rank 9 contributes only 0.111.

    Comparing good vs. bad rankings:

    | Scenario | True edge ranks | MRR |
    |---|---|---|
    | Perfect | 1, 2, 3, 4 | \(\frac{1}{4}(1 + 0.5 + 0.33 + 0.25) = 0.52\) |
    | Our example | 1, 4, 6, 9 | 0.38 |
    | Terrible | 7, 8, 9, 10 | \(\frac{1}{4}(0.14 + 0.125 + 0.11 + 0.1) = 0.12\) |

    The metric directly captures “how quickly do I find true edges if I go down the ranked list?” This is exactly what matters for recommendation systems where users only see the top few suggestions.
    ↩︎
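
The MRR calculation above, following this footnote's definition (the mean of reciprocal ranks over all true edges in one ranked list), can be reproduced as follows. The function name `mean_reciprocal_rank` is hypothetical, and the input is the True Edge? column encoded as 1/0 in rank order.

```python
import numpy as np


def mean_reciprocal_rank(is_true_edge):
    """Mean of 1/rank over the positions of true edges in a ranked list."""
    labels = np.asarray(is_true_edge, dtype=bool)
    ranks = np.flatnonzero(labels) + 1  # 1-based ranks of the true edges
    return float(np.mean(1.0 / ranks))


example = [1, 0, 0, 1, 0, 1, 0, 0, 1, 0]   # true edges at ranks 1, 4, 6, 9
perfect = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]   # true edges at ranks 1-4
terrible = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]  # true edges at ranks 7-10

mrr_example = mean_reciprocal_rank(example)
mrr_perfect = mean_reciprocal_rank(perfect)
mrr_terrible = mean_reciprocal_rank(terrible)

print(f"Our example: {mrr_example:.3f}")
print(f"Perfect:     {mrr_perfect:.3f}")
print(f"Terrible:    {mrr_terrible:.3f}")
```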
  5. Using the same ranked list:

    | Rank | Node Pair | True Edge? |
    |---|---|---|
    | 1 | (A, B) | Yes |
    | 2 | (C, D) | No |
    | 3 | (E, F) | No |
    | 4 | (G, H) | Yes |
    | 5 | (I, J) | No |
    | 6 | (K, L) | Yes |
    | 7 | (M, N) | No |
    | 8 | (O, P) | No |
    | 9 | (Q, R) | Yes |
    | 10 | (S, T) | No |

    We have 4 true edges total. Hits@k asks: “What fraction of true edges appear in the top k?”

    Hits@3

    Look at top 3 positions: ranks 1, 2, 3

    True edges found: just (A, B) → 1 edge

    \[\text{Hits@3} = \frac{1}{4} = 0.25\]

    Hits@5

    Look at top 5 positions: ranks 1, 2, 3, 4, 5

    True edges found: (A, B) and (G, H) → 2 edges

    \[\text{Hits@5} = \frac{2}{4} = 0.50\]

    Hits@10

    Look at all 10 positions

    True edges found: all 4

    \[\text{Hits@10} = \frac{4}{4} = 1.0\]

    Interpretation

    If your recommendation system shows 5 suggestions, Hits@5 = 0.50 means users would see half of the relevant items. The metric is simple and directly actionable: if Hits@10 is great but Hits@3 is poor, your model ranks true edges in the middle of the list rather than at the top.↩︎
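
The three Hits@k values worked out above can be checked with a short helper. This is a sketch using the footnote's definition (fraction of all true edges that land in the top k); `hits_at_k` is a hypothetical name and the input is the True Edge? column in rank order.

```python
import numpy as np


def hits_at_k(is_true_edge, k):
    """Fraction of all true edges that appear in the top k ranked positions."""
    labels = np.asarray(is_true_edge, dtype=bool)
    return float(labels[:k].sum() / labels.sum())


ranked_labels = [1, 0, 0, 1, 0, 1, 0, 0, 1, 0]  # true edges at ranks 1, 4, 6, 9

hits = {k: hits_at_k(ranked_labels, k) for k in (3, 5, 10)}
for k, value in hits.items():
    print(f"Hits@{k} = {value:.2f}")
```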