Efficient leave one out cross validation - part 1

The derivation and implementation of a method for leave-one-out cross-validation with negligible extra runtime compared to fitting alone.
Author

Tom Shlomo

Published

February 27, 2024

Cross-validation is a crucial technique for assessing the performance of machine learning models. K-fold cross-validation, a widely used method, involves dividing the dataset into K subsets and training the model K times, each time holding out a different subset as the test set and training on the remaining ones. This helps us gauge how well our model generalizes to unseen data. However, as K increases, so does the computational cost, since the model must be fit K times. This becomes painfully evident during hyperparameter tuning, where slow fits can be a major bottleneck.

Leave-one-out cross-validation (LOOCV) is the special case of K-fold cross-validation where K equals the number of training samples. It can offer an accurate evaluation, since each model is trained on all but one sample, but it requires one fit per sample. This hefty computational cost makes it less practical for larger datasets and for hyperparameter tuning.
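To make the "special case" concrete, here is a small check using scikit-learn's splitters on toy data (the data is purely illustrative): without shuffling, KFold with n_splits equal to the number of samples produces exactly the same train/test splits as LeaveOneOut, with one split, and therefore one model fit, per sample.

```python
import numpy as np
from sklearn.model_selection import KFold, LeaveOneOut

X = np.arange(12.0).reshape(6, 2)  # toy data: 6 samples, 2 features
n = X.shape[0]

loo_splits = list(LeaveOneOut().split(X))
kfold_splits = list(KFold(n_splits=n).split(X))  # shuffle=False by default

# Compare the splits pairwise: identical train and test indices.
same_splits = all(
    np.array_equal(loo_train, kf_train) and np.array_equal(loo_test, kf_test)
    for (loo_train, loo_test), (kf_train, kf_test) in zip(loo_splits, kfold_splits)
)
print(len(loo_splits), same_splits)  # prints: 6 True
```

Each of the n splits leaves out a single sample, which is why the naive cost of LOOCV grows with the dataset size.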

For linear models like ordinary least squares and ridge regression, a little-known trick exists to calculate LOOCV scores efficiently. scikit-learn even implements it in its RidgeCV estimator. Notably, the same trick extends beyond these linear models to any quadratically regularized least squares regression, a fact that is not widely recognized.
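As a sanity check of the RidgeCV claim, the sketch below compares the regularization strength chosen by RidgeCV against a brute-force search. With its default cv=None, RidgeCV uses efficient leave-one-out cross-validation. The brute-force search instead refits Ridge n times per candidate alpha. The random data and the alpha grid are arbitrary. fit_intercept=False is set so that the objective is exactly \(\|y - X\theta\|^2 + \alpha \|\theta\|^2\).

```python
import numpy as np
from sklearn.linear_model import Ridge, RidgeCV
from sklearn.model_selection import LeaveOneOut, cross_val_score

rng = np.random.default_rng(0)
X = rng.standard_normal((40, 5))
y = X @ rng.standard_normal(5) + rng.standard_normal(40)
alphas = [0.01, 0.1, 1.0, 10.0, 100.0]

# Default cv=None: RidgeCV scores each alpha with efficient leave-one-out CV.
ridge_cv = RidgeCV(alphas=alphas, fit_intercept=False).fit(X, y)

# Brute force: n refits per alpha, averaging the squared leave-one-out errors.
loo_mse = [
    -cross_val_score(
        Ridge(alpha=alpha, fit_intercept=False),
        X,
        y,
        cv=LeaveOneOut(),
        scoring="neg_mean_squared_error",
    ).mean()
    for alpha in alphas
]
best_alpha = alphas[int(np.argmin(loo_mse))]
print(ridge_cv.alpha_, best_alpha)
```

Both approaches should select the same alpha. RidgeCV gets there without any refitting.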

Taking it a step further, even for non-least-squares models like logistic and Poisson regression, a similar trick can be employed to approximate LOOCV scores efficiently. Intriguingly, the accuracy of this approximation improves with larger datasets, addressing the need for speedup in precisely those scenarios.

In this initial segment, we derive efficient LOOCV for the quadratic scenario and demonstrate its implementation in Python.

In part 2, we will build upon this derivation to cover non-quadratic scenarios and showcase these findings with a practical example dataset.

Notation

We denote the number of samples in the training dataset as \(n\).

The \(m\)-dimensional feature vectors are represented as \(x_1\) to \(x_n\), forming the rows of matrix \(X\).

Targets are denoted as \(y_1\) to \(y_n\), forming the vector \(y\). The model’s prediction for the \(i\)-th training sample is \(\hat{y}_i = x_i^T \theta\), where \(\theta\) is the coefficients vector. \(\hat{y} = X \theta\) represents the vector containing all predictions.

We fit \(\theta\) to the training data by minimizing the combined loss and regularization terms: \[ \theta := \arg\min_{\theta'} f(\theta'), \tag{1}\] where \[ f(\theta') := \sum_{i=1}^{n} l(x_i^T \theta'; y_i) + r(\theta'). \] Here, \(l(\hat{y}_i; y_i)\) is the loss function, quantifying the discrepancy between the prediction \(\hat{y}_i\) and the true target \(y_i\), while \(r\) is the regularization function. We assume that \(l\) (as a function of \(\hat{y}_i\)) and \(r\) are convex and twice differentiable. Special cases of this model include ordinary least squares (\(l(\hat{y}_i; y_i) = (\hat{y}_i - y_i)^2\), \(r(\theta') = 0\)), ridge regression (\(l(\hat{y}_i; y_i) = (\hat{y}_i - y_i)^2\), \(r(\theta') = \alpha \| \theta' \|^2\)), logistic regression (\(l(\hat{y}_i;y_i) = \log \left( 1 + e^{-y_i \hat{y}_i}\right)\) with \(y_i \in \{ -1, 1\}\)), and Poisson regression (\(l(\hat{y}_i;y_i) = e^{\hat{y}_i} - y_i \hat{y}_i\), the negative log-likelihood up to a constant).
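The losses listed above are easy to write down explicitly. The sketch below is only an illustration and not part of the method. It codes each loss together with its first and second derivatives with respect to \(\hat{y}_i\). The Poisson loss is written in its convex negative log-likelihood form \(e^{\hat{y}_i} - y_i \hat{y}_i\). The sketch then checks the derivatives against central finite differences and confirms the convexity assumption, i.e., that each second derivative is nonnegative. The function names are hypothetical.

```python
import numpy as np

# Each function returns l(y_hat; y) and its first and second derivatives w.r.t. y_hat.
def squared_loss(y_hat, y):
    r = y_hat - y
    return r**2, 2 * r, 2 * np.ones_like(r)

def logistic_loss(y_hat, y):  # labels y in {-1, 1}
    z = -y * y_hat
    s = 1 / (1 + np.exp(-z))  # sigmoid(z)
    return np.logaddexp(0, z), -y * s, s * (1 - s)

def poisson_loss(y_hat, y):  # negative log-likelihood, dropping log(y!)
    mu = np.exp(y_hat)
    return mu - y * y_hat, mu - y, mu

# Compare the analytic derivatives with central finite differences.
y_hat = np.linspace(-3, 3, 61)
eps = 1e-5
results = {}
for name, loss, y in [("squared", squared_loss, 0.7),
                      ("logistic", logistic_loss, -1.0),
                      ("poisson", poisson_loss, 2.0)]:
    l, dl, d2l = loss(y_hat, y)
    l_p, dl_p, _ = loss(y_hat + eps, y)
    l_m, dl_m, _ = loss(y_hat - eps, y)
    results[name] = (
        np.allclose(dl, (l_p - l_m) / (2 * eps), atol=1e-6),
        np.allclose(d2l, (dl_p - dl_m) / (2 * eps), atol=1e-6),
        bool(np.all(d2l >= 0)),  # convexity in y_hat
    )
print(results)
```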

To denote the coefficients obtained by excluding the \(j\)-th example, we use \(\theta^{(j)}\): \[ \theta^{(j)} = \arg\min_{\theta'} f^{(j)} (\theta') \] where \[ f^{(j)}(\theta') := \sum_{i \neq j} l(x_i^T \theta'; y_i) + r(\theta'). \] Similarly, \(X^{(j)}\) and \(y^{(j)}\) denote \(X\) and \(y\) with the \(j\)-th row (respectively, entry) removed. We denote by \(\tilde{y}_j\) the predicted label for sample \(j\) when it is left out: \[ \tilde{y}_j := x_j ^T \theta^{(j)} \tag{2}\] Our goal is to calculate \(\tilde{y}_j\) efficiently for all \(j\).

Deriving efficient LOOCV for the quadratic case

When the loss function is the squared loss, \[ l(\hat{y}_i; y_i) = (\hat{y}_i - y_i)^2, \] and the regularizer is quadratic, \[ r(\theta') = \theta'^T R \theta', \] where \(R\) is a symmetric positive semi-definite \(m \times m\) matrix, the solution to the optimization problem in Equation 1 is obtained by solving the linear system (footnote 1) \[ A \theta = b, \tag{3}\] where \[\begin{align*} A &:= X^T X + R \\ b &:= X^T y. \end{align*}\]
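The linear system comes from setting the gradient of \(f\), which is \(2X^T(X\theta - y) + 2R\theta\) for symmetric \(R\), to zero. The quick numerical check below uses random toy data. It compares the solution of Equation 3 with the minimizer found by a generic optimizer (SciPy's BFGS). As an illustrative assumption, \(R\) is taken to be a penalty \(\alpha \|D\theta\|^2\) on differences between adjacent coefficients. This \(R\) is singular but positive semi-definite, which shows that \(R\) need not be a multiple of the identity or even invertible. \(A\) stays positive definite here because \(X\) has full column rank.

```python
import numpy as np
import scipy.linalg
import scipy.optimize

rng = np.random.default_rng(1)
n, m = 50, 6
X = rng.standard_normal((n, m))
y = rng.standard_normal(n)
# Hypothetical PSD regularizer: alpha * ||D theta||^2 with D the first-difference
# matrix, i.e. R = alpha * D^T D. It is singular (constant vectors are not penalized).
D = np.diff(np.eye(m), axis=0)  # shape (m - 1, m)
R = 3.0 * D.T @ D

def f(theta):
    """Objective of Equation 1 with squared loss and r(theta) = theta^T R theta."""
    return np.sum((X @ theta - y) ** 2) + theta @ R @ theta

def grad_f(theta):
    """Gradient of f; setting it to zero gives (X^T X + R) theta = X^T y."""
    return 2 * X.T @ (X @ theta - y) + 2 * R @ theta

theta = scipy.linalg.solve(X.T @ X + R, X.T @ y, assume_a="pos")  # Equation 3
res = scipy.optimize.minimize(f, np.zeros(m), jac=grad_f, method="BFGS")
print(np.max(np.abs(theta - res.x)))
```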

Similarly, obtaining \(\theta^{(j)}\) requires solving \[ A^{(j)} \theta^{(j)} = b^{(j)}, \tag{4}\] where \[\begin{align*} A^{(j)} &:= X^{(j)T} X^{(j)} + R \\ b^{(j)} &:= X^{(j)T} y^{(j)}. \end{align*}\] Forming and solving Equation 4 costs \(O(m^3 + n m^2)\) for a single \(j\). A naive implementation of LOOCV repeats this for every \(j\), for an overall complexity of \(O(n m^3 + n^2 m^2)\). This becomes a significant computational burden, particularly when \(n\) is large.
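For reference, a direct implementation of this naive approach might look like the sketch below. The function name naive_loocv_predict is hypothetical and does not come from any library. The function rebuilds \(A^{(j)}\) and \(b^{(j)}\) from scratch for every \(j\), which is where the \(O(n m^3 + n^2 m^2)\) cost comes from.

```python
import numpy as np

def naive_loocv_predict(X, y, R):
    """Leave-one-out predictions by forming and solving Equation 4 for every j."""
    n = X.shape[0]
    y_tilde = np.empty(n)
    for j in range(n):
        X_loo = np.delete(X, j, axis=0)      # X^(j): X without row j
        y_loo = np.delete(y, j)              # y^(j): y without entry j
        A_j = X_loo.T @ X_loo + R            # O(n m^2)
        b_j = X_loo.T @ y_loo
        theta_j = np.linalg.solve(A_j, b_j)  # O(m^3)
        y_tilde[j] = X[j] @ theta_j          # Equation 2
    return y_tilde

# Usage on random data, with ridge regularization R = alpha * I.
rng = np.random.default_rng(2)
X = rng.standard_normal((20, 4))
y = rng.standard_normal(20)
alpha = 0.5
y_tilde = naive_loocv_predict(X, y, alpha * np.eye(4))
```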

Efficient LOOCV reuses the work done for Equation 3 to solve Equation 4. It relies on a standard fact from numerical linear algebra: once an \(m \times m\) matrix has been factored (for example, with a Cholesky factorization, at a cost of \(O(m^3)\)), each additional right-hand side costs only \(O(m^2)\). So, in addition to Equation 3, we solve the following \(n\) equations with the same matrix \(A\): \[ A t_j = x_j, \quad j = 1, \ldots, n, \] at a total cost of \(O(m^3 + n m^2)\), the same order as fitting the model once.
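In SciPy, one way to make this reuse explicit is to factor \(A\) once with scipy.linalg.cho_factor and then call scipy.linalg.cho_solve for each right-hand side. Passing all right-hand sides at once as the columns of a matrix, as the implementation further down does with scipy.linalg.solve, achieves the same thing. The sketch uses random toy data and takes \(R = I\) as an arbitrary choice.

```python
import numpy as np
import scipy.linalg

rng = np.random.default_rng(3)
n, m = 200, 8
X = rng.standard_normal((n, m))
y = rng.standard_normal(n)
A = X.T @ X + np.eye(m)  # R = I, an arbitrary choice
b = X.T @ y

# Factor A once: O(m^3).
c_and_lower = scipy.linalg.cho_factor(A)
# Each solve below reuses the factorization and costs O(m^2) per right-hand side.
theta = scipy.linalg.cho_solve(c_and_lower, b)  # Equation 3
T = scipy.linalg.cho_solve(c_and_lower, X.T)    # column j is t_j, with A t_j = x_j
```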

We start by noting that \[\begin{align*} X^TX &= X^{(j)T} X^{(j)} + x_j x_j^T \\ X^Ty &= X^{(j)T} y^{(j)} + x_j y_j, \end{align*}\] so \(A^{(j)} = A - x_j x_j^T\) and \(b^{(j)} = b - x_j y_j\), and we can write Equation 4 as \[ (A - x_j x_j^T) \theta^{(j)} = b - x_j y_j. \] The usual way forward is to apply the Sherman-Morrison formula, solve for \(\theta^{(j)}\), and substitute it into Equation 2 to obtain an expression for \(\tilde{y}_j\). However, there’s a better approach (footnote 2). Since \(x_j^T \theta^{(j)} = \tilde{y}_j\) by Equation 2, we rewrite Equation 4 as \[\begin{align*} A \theta^{(j)} - x_j \tilde{y}_j &= b - x_j y_j \\ \tilde{y}_j &= x_j ^T \theta^{(j)} \end{align*}\] So instead of a single equation with one unknown (\(\theta^{(j)}\)), we now have two equations with two unknowns (\(\theta^{(j)}\) and \(\tilde{y}_j\)). At first this seems more complicated. But the coefficient of \(\theta^{(j)}\) in the first equation is \(A\) itself, so we can eliminate \(\theta^{(j)}\) using \(\theta = A^{-1} b\) and \(t_j = A^{-1} x_j\) (assuming \(A\) is invertible): \[\begin{align*} \theta^{(j)} &= A^{-1} ( b - x_j y_j + x_j \tilde{y}_j ) \\ &= \theta - t_j ( y_j - \tilde{y}_j ). \end{align*}\] Substituting this into the second equation, we can solve for \(\tilde{y}_j\): \[\begin{align*} \tilde{y}_j &= x_j ^T \left( \theta - t_j ( y_j - \tilde{y}_j ) \right) \\ \tilde{y}_j &= \hat{y}_j - h_j (y_j - \tilde{y}_j) \\ \tilde{y}_j &= \frac{\hat{y}_j - h_j y_j}{1-h_j} \\ \tilde{y}_j &= \hat{y}_j + \frac{h_j }{1-h_j} \left( \hat{y}_j - y_j \right) \end{align*}\] where \[ h_j := x_j ^T t_j. \] The division requires \(h_j \neq 1\). Since \(\det(A - x_j x_j^T) = (1 - h_j)\det(A)\), this holds exactly when \(A^{(j)}\) is invertible, i.e., when the leave-one-out problem has a unique solution.
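Each step of this derivation can be checked numerically. The sketch below uses random toy data, with \(R = 0.1 I\) and the left-out index \(j = 7\) as arbitrary choices. It checks three things against a direct fit without sample \(j\): the rank-one identity, the expression for \(\theta^{(j)}\) obtained by elimination, and the final closed form for \(\tilde{y}_j\).

```python
import numpy as np

rng = np.random.default_rng(4)
n, m, j = 30, 5, 7  # j is the (arbitrary) index of the left-out sample
X = rng.standard_normal((n, m))
y = rng.standard_normal(n)
R = 0.1 * np.eye(m)

# Full-data quantities: A, b, theta (Equation 3), and t_j (A t_j = x_j).
A, b = X.T @ X + R, X.T @ y
theta = np.linalg.solve(A, b)
t_j = np.linalg.solve(A, X[j])

# Direct leave-one-out fit (Equation 4) and its prediction (Equation 2).
X_loo, y_loo = np.delete(X, j, axis=0), np.delete(y, j)
theta_loo = np.linalg.solve(X_loo.T @ X_loo + R, X_loo.T @ y_loo)
y_tilde_j = X[j] @ theta_loo

# Step 1: the rank-one identity A^(j) = A - x_j x_j^T.
rank_one_ok = np.allclose(X_loo.T @ X_loo + R, A - np.outer(X[j], X[j]))
# Step 2: eliminating theta^(j) gives theta^(j) = theta - t_j (y_j - y_tilde_j).
elimination_ok = np.allclose(theta_loo, theta - t_j * (y[j] - y_tilde_j))
# Step 3: the closed form for y_tilde_j.
y_hat_j, h_j = X[j] @ theta, X[j] @ t_j
closed_form_ok = np.isclose(y_tilde_j, y_hat_j + h_j / (1 - h_j) * (y_hat_j - y[j]))
print(rank_one_ok, elimination_ok, closed_form_ok)
```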

Reminder

\(y_j\) is the true label.
\(\hat{y}_j\) is the prediction using all the data.
\(\tilde{y}_j\) is the leave-one-out prediction.

That’s it! We have an expression for \(\tilde{y}_j\) that doesn’t require inverting (or solving with) any matrix other than \(A\). It also has a nice interpretation: the difference between the LOO prediction and the full-data prediction, \(\tilde{y}_j - \hat{y}_j\), is the difference between the full-data prediction and the true label, \(\hat{y}_j - y_j\), “amplified” by \(\frac{h_j}{1-h_j}\).
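Two consequences of this formula are easy to check numerically. First, \(h_j\) is the \(j\)-th diagonal entry of \(H := X A^{-1} X^T\), the matrix that maps \(y\) to \(\hat{y}\). \(H\) is often called the hat matrix, and \(h_j\) the leverage of sample \(j\). Second, rearranging the formula gives \(y_j - \tilde{y}_j = (y_j - \hat{y}_j)/(1 - h_j)\). Since \(0 \le h_j < 1\) here, the leave-one-out residual is never smaller in magnitude than the in-sample residual. The sketch below uses random toy data and takes \(R = 2I\) as an arbitrary choice. Forming \(H\) explicitly is only reasonable for a small demo like this one.

```python
import numpy as np

rng = np.random.default_rng(5)
n, m = 40, 6
X = rng.standard_normal((n, m))
y = rng.standard_normal(n)
A = X.T @ X + 2.0 * np.eye(m)  # R = 2 I, an arbitrary choice

# Hat matrix H = X A^{-1} X^T (n x n; fine for a small demo, not for large n).
H = X @ np.linalg.solve(A, X.T)
y_hat = H @ y                  # same as X @ theta
h = np.diag(H)                 # h_j = x_j^T t_j
y_tilde = y_hat + h / (1 - h) * (y_hat - y)

# Leave-one-out residuals vs. in-sample residuals.
loo_residual = y - y_tilde
in_sample_residual = y - y_hat
```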

Python implementation

The approach outlined above translates directly into code. We’ll construct an estimator in the style of scikit-learn, with the standard fit and predict methods, alongside a method, fit_loocv_predict, that computes \(\tilde{y}\), the leave-one-out predictions:

from typing import Self

import numpy as np
import scipy.linalg


class LinearRegressionWithQuadraticRegularization:
    def __init__(self, R) -> None:
        self.R = R

    def fit(self, X, y) -> Self:
        A = X.T @ X + self.R
        b = X.T @ y
        self.theta_ = scipy.linalg.solve(
            A,
            b,
            overwrite_a=True,
            overwrite_b=True,
            assume_a="pos",
        )
        return self

    def predict(self, X) -> np.ndarray:
        return X @ self.theta_

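    def fit_loocv_predict(self, X, y) -> np.ndarray:
        A = X.T @ X + self.R
        b = X.T @ y
        # Solve A [theta, t_1, ..., t_n] = [b, x_1, ..., x_n] in a single call,
        # so A is factored only once (Equation 3 and A t_j = x_j).
        temp = scipy.linalg.solve(
            A,
            np.vstack([b, X]).T,  # shape (m, n + 1): first column b, then the x_j
            overwrite_a=True,
            overwrite_b=True,
            assume_a="pos",
        )
        self.theta_ = temp[:, 0]
        t = temp[:, 1:]  # column j is t_j
        h = np.einsum("ij,ji->i", X, t)  # h[i] = np.dot(X[i, :], t[:, i])
        y_hat = self.predict(X)
        # Leave-one-out predictions: y_tilde = y_hat + h / (1 - h) * (y_hat - y)
        return y_hat + (h / (1 - h)) * (y_hat - y)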

Let’s check on random data that our method for calculating the leave-one-out predictions is correct, and compare its runtime to that of the standard leave-one-out procedure.

from sklearn.model_selection import LeaveOneOut


def standard_loocv(model, X, y) -> np.ndarray:
    y_tilde = np.empty_like(y)
    for i, (train_index, test_index) in enumerate(LeaveOneOut().split(X)):
        X_loo = X[train_index, :]
        y_loo = y[train_index]
        model.fit(X_loo, y_loo)
        y_tilde[i] = model.predict(X[test_index, :])[0]
    return y_tilde


def gen_random_data(n: int, m: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    rng = np.random.default_rng(42)
    X = rng.standard_normal((n, m))
    L = rng.standard_normal((m, m))
    theta = L @ rng.standard_normal(m)
    y = X @ theta + rng.standard_normal(n)
    R = L @ L.T  # random positive definite matrix
    return X, y, R


X, y, R = gen_random_data(n=100, m=10)
model = LinearRegressionWithQuadraticRegularization(R=R)
print(
    f"max absolute error: {np.max(np.abs(model.fit_loocv_predict(X, y) - standard_loocv(model, X, y))):.3e}"
)
max absolute error: 1.243e-14

Good, the two methods of calculating \(\tilde{y}\) agree up to floating-point round-off. Let’s also compare the runtimes:

%timeit model.fit_loocv_predict(X, y) 
%timeit standard_loocv(model, X, y)
34.6 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
2.39 ms ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Nice, a speedup of roughly 70×. But both were quite fast to begin with. Let’s increase n and m:

X, y, R = gen_random_data(n=1000, m=50)
model = LinearRegressionWithQuadraticRegularization(R=R)
print(f'max absolute error: {np.max(np.abs(model.fit_loocv_predict(X, y) - standard_loocv(model, X, y))):.3e}')
%timeit model.fit_loocv_predict(X, y) 
%timeit standard_loocv(model, X, y)
max absolute error: 8.527e-14
138 ms ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
The slowest run took 4.24 times longer than the fastest. This could mean that an intermediate result is being cached.
822 ms ± 461 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Hmm… Much less impressive. In theory, the speedup should grow with problem size: the naive approach solves n systems like Equation 4, so it costs roughly n times as much as the efficient one. The disappointing result is likely due to Python overheads, not the algorithm itself. Let’s try to improve on this with JAX’s just-in-time (JIT) compilation:

import jax


class JitLinearRegressionWithQuadraticRegularization:
    def __init__(self, R) -> None:
        self.R = R

    def fit(self, X, y) -> Self:
        self.theta_ = self._fit(X, y, self.R)
        return self

    def predict(self, X) -> np.ndarray:
        return self._predict(X, self.theta_)

    def fit_loocv_predict(self, X, y) -> np.ndarray:
        self.theta_, y_tilde = self._fit_loocv_predict(X, y, self.R)
        return y_tilde
    
    @staticmethod
    @jax.jit
    def _fit(X, y, R) -> np.ndarray:
        return jax.scipy.linalg.solve(
            X.T @ X + R, 
            X.T @ y,
            overwrite_a=True,
            overwrite_b=True,
            assume_a="pos",
        )

    @staticmethod
    @jax.jit
    def _predict(X, theta) -> np.ndarray:
        return X @ theta

    @staticmethod
    @jax.jit
    def _fit_loocv_predict(X, y, R) -> tuple[jax.Array, jax.Array]:
        temp = jax.scipy.linalg.solve(
            X.T @ X + R,
            jax.numpy.vstack([X.T @ y, X]).T,
            overwrite_a=True,
            overwrite_b=True,
            assume_a="pos",
        )
        theta = temp[:, 0]
        t = temp[:, 1:]
        h = jax.numpy.einsum("ij,ji->i", X, t)  # h[i] = np.dot(X[i, :], t[:, i])
        y_hat = X @ theta
        return theta, y_hat + (h / (1 - h)) * (y_hat - y)
    
model = JitLinearRegressionWithQuadraticRegularization(R=R)
print(f'max absolute error: {np.max(np.abs(model.fit_loocv_predict(X, y) - standard_loocv(model, X, y))):.3e}')
%timeit model.fit_loocv_predict(X, y).block_until_ready()
%timeit standard_loocv(model, X, y)
max absolute error: 4.780e-05
1.75 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
353 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

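Much better! The maximum error did grow to about 5e-05. This is most likely because JAX uses 32-bit floats by default, so both computations in this comparison run in single precision. Enabling 64-bit mode with jax.config.update("jax_enable_x64", True) before running should restore double-precision agreement.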

Footnotes

  1. I am deliberately avoiding writing \(\theta = A^{-1} b\): \(A\) does not have to be invertible for this equation to have a solution, and avoiding the inverse spares me the usual “assuming full rank” caveats people tend to add here. Furthermore, writing \(A^{-1}\) can mislead people into implementations like np.linalg.inv(A) @ b, which are less numerically stable and less efficient than implementations like np.linalg.solve(A, b).

  2. This approach translates better into code, as we get the expression for \(\tilde{y}_j\) directly, without going through an expression for \(\theta^{(j)}\) first. I also think Sherman-Morrison is a bit too strong here and can obscure some insights, so it’s nice to avoid it. But actually, the approach used here is essentially the first half of the Sherman-Morrison proof (see for example here).
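For comparison with footnote 2, the Sherman-Morrison route can be sketched as follows, using random toy data and \(R = I\). For each \(j\), the sketch builds \((A - x_j x_j^T)^{-1}\) from \(A^{-1}\), solves for \(\theta^{(j)}\), and predicts. It arrives at the same leave-one-out predictions as the direct formula. The explicit inverse is used only to mirror the textbook formula and is not a recommended implementation.

```python
import numpy as np

rng = np.random.default_rng(6)
n, m = 25, 4
X = rng.standard_normal((n, m))
y = rng.standard_normal(n)
A = X.T @ X + np.eye(m)  # R = I, an arbitrary choice
b = X.T @ y
A_inv = np.linalg.inv(A)  # explicit inverse only to mirror the formula
theta = A_inv @ b

# Sherman-Morrison route: invert A^(j) = A - x_j x_j^T, solve for theta^(j), predict.
y_tilde_sm = np.empty(n)
for j in range(n):
    x = X[j]
    u = A_inv @ x  # u = A^{-1} x_j = t_j
    # (A - x x^T)^{-1} = A^{-1} + u u^T / (1 - x^T u)
    A_j_inv = A_inv + np.outer(u, u) / (1 - x @ u)
    theta_j = A_j_inv @ (b - x * y[j])  # b^(j) = b - x_j y_j
    y_tilde_sm[j] = x @ theta_j

# Direct formula from the derivation in the main text.
T = np.linalg.solve(A, X.T)  # column j is t_j
h = np.einsum("ij,ji->i", X, T)
y_hat = X @ theta
y_tilde = y_hat + h / (1 - h) * (y_hat - y)
```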