🧮 Brain Teaser
ML/Stats
Stein's Paradox
2026-03-29

Stein's Paradox

Problem

You observe a single sample $X \sim \mathcal{N}(\mu, I_d)$ where $\mu \in \mathbb{R}^d$ is unknown. You want to estimate $\mu$ under squared error loss:

$$L(\hat{\mu}, \mu) = \|\hat{\mu} - \mu\|^2$$

The obvious estimator is the MLE, $\hat{\mu}_{\mathrm{MLE}} = X$, which has risk $\mathrm{MSE} = d$.

For $d = 1$ and $d = 2$, the MLE is admissible: no estimator can uniformly beat it.

For $d \geq 3$: show that the MLE is inadmissible, i.e., exhibit an estimator $\tilde{\mu}$ such that

$$E\|\tilde{\mu}(X) - \mu\|^2 < d \quad \text{for all } \mu \in \mathbb{R}^d.$$

The James–Stein estimator is

$$\hat{\mu}_{JS} = \left(1 - \frac{d-2}{\|X\|^2}\right) X$$

Show that $E\|\hat{\mu}_{JS} - \mu\|^2 < d$ for all $\mu$ when $d \geq 3$.
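A quick way to see the claim numerically is a Monte Carlo comparison of the two risks. This is a minimal sketch (NumPy); the dimensions, the choice $\|\mu\| = 2$, and the trial count are arbitrary illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def risks(d, mu_norm=2.0, n_trials=200_000):
    # Risk depends on mu only through its norm, so fix an arbitrary mean with ||mu|| = mu_norm.
    mu = np.zeros(d)
    mu[0] = mu_norm
    X = rng.normal(loc=mu, size=(n_trials, d))        # X ~ N(mu, I_d), one draw per trial
    sq_norm = np.sum(X**2, axis=1, keepdims=True)
    js = (1.0 - (d - 2) / sq_norm) * X                # James-Stein estimate
    mle_risk = np.mean(np.sum((X - mu)**2, axis=1))   # should be close to d
    js_risk = np.mean(np.sum((js - mu)**2, axis=1))   # strictly below d for d >= 3
    return mle_risk, js_risk

for d in (3, 5, 10):
    mle_risk, js_risk = risks(d)
    print(f"d={d:2d}  MLE risk ~ {mle_risk:.3f}   JS risk ~ {js_risk:.3f}")
```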


Field

Statistics / Machine Learning

Why It's Beautiful

This is one of the most startling results in all of statistics. It says: even if you are estimating $d \geq 3$ completely unrelated quantities (e.g., the temperature in Toronto, the GDP of Peru, and the mass of Jupiter), you should shrink all your estimates toward zero together, and this provably beats treating each problem independently.

The result shattered the intuition that "optimal estimation of independent quantities should be done independently." It led directly to the development of empirical Bayes methods, regularization (ridge regression shrinks toward zero for exactly this reason), and shrinkage estimators throughout modern statistics and ML.

Efron called it "the most striking result in post-war mathematical statistics."

Key Idea / Trick

Use Stein's identity: for $X \sim \mathcal{N}(\mu, I_d)$ and any weakly differentiable $g: \mathbb{R}^d \to \mathbb{R}^d$:

$$E\langle X - \mu,\ g(X)\rangle = E[\nabla \cdot g(X)]$$

Write $\hat{\mu}_{JS} = X + g(X)$ with $g(X) = -\frac{d-2}{\|X\|^2} X$, expand the squared loss, and apply the identity to evaluate the cross-term. The risk drops below $d$ precisely because $d - 2 > 0$.
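A small numerical sanity check of the identity for this specific $g$ can be done by Monte Carlo; in this sketch the dimension, mean, and sample size are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 5, 500_000
mu = rng.normal(size=d)                               # arbitrary mean vector
X = rng.normal(loc=mu, size=(n, d))                   # X ~ N(mu, I_d)
sq = np.sum(X**2, axis=1)

g = -(d - 2) * X / sq[:, None]                        # g(X) = -(d-2) X / ||X||^2
lhs = np.mean(np.sum((X - mu) * g, axis=1))           # E<X - mu, g(X)>
rhs = np.mean(-(d - 2)**2 / sq)                       # E[div g(X)] = -(d-2)^2 E[1/||X||^2]
print(lhs, rhs)                                       # the two averages should roughly agree
```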

Difficulty

4 / 5

Tags

Statistics, Estimation, Admissibility, James-Stein, Shrinkage, Stein's identity, Empirical Bayes, Regularization, MSE


Stein's Paradox — Answer

Setup

Let $X \sim \mathcal{N}(\mu, I_d)$. Write the James–Stein estimator as:

$$\hat{\mu}_{JS} = X + g(X), \qquad g(X) = -\frac{d-2}{\|X\|^2}\, X$$

The risk of any estimator $X + g(X)$ expands as:

$$E\|X + g(X) - \mu\|^2 = E\|X - \mu\|^2 + 2E\langle X - \mu,\, g(X)\rangle + E\|g(X)\|^2$$
$$= d + 2E\langle X - \mu,\, g(X)\rangle + E\|g(X)\|^2 \tag{1}$$


Stein's Identity

Lemma (Stein, 1981). If $X \sim \mathcal{N}(\mu, I_d)$ and $g: \mathbb{R}^d \to \mathbb{R}^d$ is weakly differentiable with $E|\nabla \cdot g(X)| < \infty$, then:

$$E\langle X - \mu,\, g(X)\rangle = E[\nabla \cdot g(X)]$$

Proof sketch for $d = 1$: The Gaussian density $\phi(x) = e^{-(x-\mu)^2/2}/\sqrt{2\pi}$ satisfies $\phi'(x) = -(x-\mu)\phi(x)$, so integration by parts gives

$$E[(X-\mu)g(X)] = \int g(x)(x-\mu)\phi(x)\,dx = -\int g(x)\phi'(x)\,dx = \int g'(x)\phi(x)\,dx = E[g'(X)],$$

where the boundary terms vanish because $\phi$ decays faster than any polynomial. The multivariate statement follows by applying this coordinate-wise.
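For a concrete instance, here is a short symbolic check of the one-dimensional identity; the choices $g(x) = x^3$ and $\mu = 3/2$ are arbitrary, and the sketch assumes SymPy is available.

```python
import sympy as sp

x = sp.symbols('x', real=True)
mu = sp.Rational(3, 2)                                       # arbitrary concrete mean
phi = sp.exp(-(x - mu)**2 / 2) / sp.sqrt(2 * sp.pi)          # N(mu, 1) density
g = x**3                                                     # arbitrary smooth test function

lhs = sp.integrate((x - mu) * g * phi, (x, -sp.oo, sp.oo))   # E[(X - mu) g(X)]
rhs = sp.integrate(sp.diff(g, x) * phi, (x, -sp.oo, sp.oo))  # E[g'(X)]
print(sp.simplify(lhs - rhs))                                # prints 0
```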


Computing the Cross-Term

For $g(X) = -\frac{d-2}{\|X\|^2} X$, compute the divergence:

$$\nabla \cdot g(X) = \nabla \cdot \left(-\frac{d-2}{\|X\|^2} X\right) = -(d-2)\, \nabla \cdot \frac{X}{\|X\|^2}$$

Using $\frac{\partial}{\partial x_i}\frac{x_i}{\|x\|^2} = \frac{\|x\|^2 - 2x_i^2}{\|x\|^4}$ and summing over $i$:

$$\nabla \cdot \frac{X}{\|X\|^2} = \frac{d\|X\|^2 - 2\|X\|^2}{\|X\|^4} = \frac{d-2}{\|X\|^2}$$

So by Stein's identity:

$$E\langle X-\mu,\, g(X)\rangle = E\left[-(d-2)\cdot\frac{d-2}{\|X\|^2}\right] = -(d-2)^2\, E\frac{1}{\|X\|^2}$$
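The divergence computation itself can be verified symbolically; a sketch with SymPy for the illustrative case $d = 3$:

```python
import sympy as sp

x1, x2, x3 = sp.symbols('x1 x2 x3', real=True)
xs = (x1, x2, x3)
d = len(xs)
sq = sum(xi**2 for xi in xs)                           # ||x||^2
field = [xi / sq for xi in xs]                         # the vector field x / ||x||^2
div = sum(sp.diff(field[i], xs[i]) for i in range(d))  # its divergence
print(sp.simplify(div - (d - 2) / sq))                 # prints 0, i.e. div = (d-2)/||x||^2
```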


Computing the Squared Norm Term

$$\|g(X)\|^2 = \frac{(d-2)^2}{\|X\|^4}\|X\|^2 = \frac{(d-2)^2}{\|X\|^2}$$


Putting It Together

Substituting into $(1)$:

$$E\|\hat{\mu}_{JS} - \mu\|^2 = d - 2(d-2)^2\, E\frac{1}{\|X\|^2} + (d-2)^2\, E\frac{1}{\|X\|^2} = d - (d-2)^2\, E\frac{1}{\|X\|^2}$$

Since $\|X\|^2 > 0$ a.s., we have $E[1/\|X\|^2] > 0$, and for $d \geq 3$ this expectation is also finite. Therefore, when $d \geq 3$:

$$\boxed{E\|\hat{\mu}_{JS} - \mu\|^2 = d - (d-2)^2\, E\frac{1}{\|X\|^2} < d}$$

for all $\mu$. The MLE is inadmissible. $\blacksquare$
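A Monte Carlo sketch (NumPy) can estimate both sides of the boxed identity on the same samples; the dimension and mean below are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(2)
d, n = 6, 500_000
mu = np.linspace(0.5, 2.0, d)                          # arbitrary mean vector
X = rng.normal(loc=mu, size=(n, d))                    # X ~ N(mu, I_d)
sq = np.sum(X**2, axis=1)

js = (1.0 - (d - 2) / sq)[:, None] * X                 # James-Stein estimates
lhs = np.mean(np.sum((js - mu)**2, axis=1))            # empirical JS risk
rhs = d - (d - 2)**2 * np.mean(1.0 / sq)               # d - (d-2)^2 E[1/||X||^2]
print(lhs, rhs)                                        # both should agree and be < d
```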


Why d = 1, 2 Fails

For $d = 2$, the shrinkage factor $(d-2)$ vanishes, so the James–Stein estimator coincides with the MLE and there is no improvement. For $d = 1$, $E[1/\|X\|^2] = \infty$ and $g$ violates the integrability conditions of Stein's identity, so the risk calculation above breaks down. In fact, for $d = 1, 2$ the MLE is admissible.

The phase transition at $d = 3$ is sharp and still not fully "intuitively explained"; it's one of those results that is mathematically clear but conceptually mysterious.
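One way to see the threshold numerically is to estimate $E[1/\|X\|^2]$ by simulation. In this rough sketch (NumPy) we take the illustrative choice $\mu = 0$, so $\|X\|^2 \sim \chi^2_d$ and the true expectation is $1/(d-2)$ for $d \geq 3$ but infinite for $d = 1, 2$, which shows up as unstable sample averages that grow with the sample size.

```python
import numpy as np

rng = np.random.default_rng(3)
for d in (1, 2, 3, 5):
    for n in (10_000, 1_000_000):
        X = rng.normal(size=(n, d))                    # mu = 0, so ||X||^2 ~ chi^2_d
        est = np.mean(1.0 / np.sum(X**2, axis=1))      # Monte Carlo estimate of E[1/||X||^2]
        print(f"d={d}, n={n:>9,}: mean of 1/||X||^2 ~ {est:.3f}")
```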


Connection to Ridge Regression

Ridge regression estimates $\hat{\beta} = (X^TX + \lambda I)^{-1}X^Ty$ (here $X$ is the design matrix), which shrinks coefficients toward zero. This is Stein shrinkage in disguise: ridge is justified not just as regularization against overfitting, but as a provably better estimator in the MSE sense when the number of parameters is $\geq 3$.
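A minimal ridge sketch (NumPy, synthetic data, all choices illustrative) showing the shrinkage effect: as $\lambda$ grows, the fitted coefficient vector is pulled toward zero.

```python
import numpy as np

rng = np.random.default_rng(4)
n, p = 200, 10
X = rng.normal(size=(n, p))                            # design matrix
beta_true = rng.normal(size=p)
y = X @ beta_true + rng.normal(size=n)                 # noisy linear responses

for lam in (0.0, 10.0, 100.0):
    # Closed-form ridge solution: (X^T X + lambda I)^{-1} X^T y
    beta_hat = np.linalg.solve(X.T @ X + lam * np.eye(p), X.T @ y)
    print(f"lambda={lam:6.1f}  ||beta_hat|| = {np.linalg.norm(beta_hat):.3f}")
```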

Type: ML/Stats