Efficient Leave-One-Out Cross-Validation - Part 2

The non-quadratic case
Author

Tom Shlomo

Published

March 30, 2024

In the first part, we developed a method for performing efficient leave-one-out cross-validation (LOOCV). That method was exact, but it required the loss and regularization functions to be quadratic. Here, we introduce a similar technique that only provides an approximation, but drops the requirement that the loss and regularization be quadratic. Additionally, we implement the method in Python using JAX and demonstrate it on a sample dataset.

Notation (same as part 1)

We denote the number of samples in the training dataset as \(n\).

The \(m\)-dimensional feature vectors are represented as \(x_1\) to \(x_n\), forming the rows of matrix \(X\).

Targets are denoted as \(y_1\) to \(y_n\), forming the vector \(y\). The model’s prediction for the \(i\)-th training sample is \(\hat{y}_i = x_i^T \theta\), where \(\theta\) is the coefficients vector. \(\hat{y} = X \theta\) represents the vector containing all predictions.

We fit \(\theta\) to the training data by minimizing the combined loss and regularization terms: \[ \theta := \arg\min_{\theta'} f(\theta'). \tag{1}\] where \[ f(\theta') := \sum_{i=1}^{n} l(x_i^T \theta'; y_i) + r(\theta'). \] Here, \(l(\hat{y}_i; y_i)\) represents the loss function, quantifying the difference between the prediction \(\hat{y}_i\) and the true target \(y_i\), while \(r\) is the regularization function. We assume \(l\) (as a function of \(\hat{y}_i\)) and \(r\) are convex and twice differentiable. Special cases of this model include ordinary least squares (\(l(\hat{y}_i; y_i) = (\hat{y}_i - y_i)^2\), \(r(\theta') = 0\)), ridge regression (\(l(\hat{y}_i; y_i) = (\hat{y}_i - y_i)^2\), \(r(\theta') = \alpha \| \theta' \|^2\)), logistic regression (\(l(\hat{y}_i;y_i) = \log \left( 1 + e^{-y_i \hat{y}_i}\right)\) with \(y_i \in \{ -1, 1\}\)), and Poisson regression (\(l(\hat{y}_i;y_i) = e^{\hat{y}_i} - y_i \hat{y}_i\), the negative log-likelihood up to a constant, with \(\hat{y}_i\) the log-rate).
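For concreteness, here is a minimal sketch of these example losses written as JAX functions. The names and the gradient check at the end are purely illustrative and are not reused later in the post:

import jax
import jax.numpy as jnp

# Illustrative per-sample losses l(y_hat; y), each convex and twice differentiable in y_hat.
def squared_loss(y_hat, y):  # ordinary least squares / ridge
    return (y_hat - y) ** 2

def logistic_loss(y_hat, y):  # y in {-1, +1}, y_hat is the log-odds
    return jnp.log1p(jnp.exp(-y * y_hat))

def poisson_loss(y_hat, y):  # y_hat is the log-rate
    return jnp.exp(y_hat) - y * y_hat

# JAX differentiates these automatically, which is what the implementation below relies on:
print(jax.grad(squared_loss, argnums=0)(3.0, 1.0))  # l'(3; 1) = 2 * (3 - 1) = 4.0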

To denote the coefficients obtained by excluding the \(j\)-th example, we use \(\theta^{(j)}\): \[ \theta^{(j)} = \arg\min_{\theta'} f^{(j)} (\theta') \] where \[ f^{(j)}(\theta') := \sum_{i \neq j} l(x_i^T \theta'; y_i) + r(\theta') \] Similarly, \(X^{(j)}\) and \(y^{(j)}\) represent \(X\) and \(y\) with the \(j\)-th row removed, respectively. We denote by \(\tilde{y}_j\) the predicted label for sample \(j\) when it is left out: \[ \tilde{y}_j := x_j ^T \theta^{(j)} \tag{2}\] Our goal is to calculate \(\tilde{y}_j\) efficiently, for all \(j\).

Deriving efficient LOOCV for the non-quadratic case

In this section, we extend our approach to scenarios where \(l\) or \(r\) are not quadratic. Although solving Equation 1 no longer reduces to solving a linear system as it did in part 1, we can resort to the following approximation: \[ H^{(j)} (\theta^{(j)} - \theta) \approx -g^{(j)} \tag{3}\] where \(H^{(j)}\) and \(g^{(j)}\) represent the Hessian and gradient of \(f^{(j)}\) at \(\theta\), respectively. The rationale is that \(\theta\) and \(\theta^{(j)}\) should be relatively close (and closer as \(n\) increases), so a single Newton step on \(f^{(j)}\), initialized at \(\theta\), should already land very close to \(\theta^{(j)}\).

Similar to the quadratic case, we can relate \(H^{(j)}\) and \(g^{(j)}\) to \(H\) and \(g\), the Hessian and gradient of \(f\) at \(\theta\): \[\begin{align*} H^{(j)} &= H - x_j l''(\hat{y}_j ; y_j) x_j^T \\ g^{(j)} &= g - x_j l'(\hat{y}_j ; y_j) = - x_j l'(\hat{y}_j ; y_j) \end{align*}\] where the last equality holds because \(\theta\) minimizes \(f\), so \(g = 0\). This allows us to rewrite Equation 3 as: \[ \left( H - x_j l''\left(\hat{y}_j ; y_j\right) x_j^T \right) \left( \theta^{(j)} - \theta \right) \approx x_j l'(\hat{y}_j ; y_j). \] Expanding, and pairing it with the definition of \(\tilde{y}_j\) as a second equation: \[\begin{align*} H \theta^{(j)} - x_j l''(\hat{y}_j ; y_j) \tilde{y}_j - H \theta + x_j l''(\hat{y}_j ; y_j) \hat{y}_j &\approx x_j l'(\hat{y}_j ; y_j) \\ \tilde{y}_j &= x_j ^T \theta^{(j)}. \end{align*}\] Now we can eliminate \(\theta^{(j)}\) and solve for \(\tilde{y}_j\): \[\begin{align*} \theta^{(j)} &\approx \theta + t_j (l'(\hat{y}_j ; y_j) + l''(\hat{y}_j ; y_j) (\tilde{y}_j - \hat{y}_j)) \\ \tilde{y}_j &\approx x_j ^T \left( \theta + t_j (l'(\hat{y}_j ; y_j) + l''(\hat{y}_j ; y_j) (\tilde{y}_j - \hat{y}_j)) \right) \\ \tilde{y}_j &\approx \hat{y}_j + \frac{h_j}{1 - h_j l''(\hat{y}_j ; y_j)} l'(\hat{y}_j ; y_j) \end{align*}\] where \(t_j := H^{-1} x_j\) and \(h_j := x_j^T t_j\).

It’s worth noting the resemblance between the expression for \(\tilde{y}_j\) here and the expression obtained for the quadratic case.
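As a quick sanity check, take the unregularized squared loss \(l(\hat{y}_j; y_j) = (\hat{y}_j - y_j)^2\), so that \(l' = 2(\hat{y}_j - y_j)\), \(l'' = 2\), \(H = 2 X^T X\), and the Newton step is exact because \(f\) is quadratic. Writing \(\bar{h}_j := x_j^T (X^T X)^{-1} x_j = 2 h_j\) for the usual leverage, the expression above becomes \[ \tilde{y}_j = \hat{y}_j + \frac{2 h_j}{1 - 2 h_j} (\hat{y}_j - y_j) = \hat{y}_j + \frac{\bar{h}_j}{1 - \bar{h}_j} (\hat{y}_j - y_j), \] which rearranges to the classical OLS leave-one-out identity \(y_j - \tilde{y}_j = (y_j - \hat{y}_j) / (1 - \bar{h}_j)\).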

Python implementation

Once more, we’ll turn to JAX, leveraging its automatic differentiation capabilities. Our estimator will take as inputs the loss and regularization functions, along with an “inverse link” function. This function can be employed to transform the predicted labels (e.g. a sigmoid to convert log-odds to probabilities in logistic regression, or an exponential to convert log-rate to rate in Poisson regression).

from typing import Callable

import jax
import numpy as np
import numpy.typing as npt
import scipy.linalg
import scipy.optimize

Array = npt.NDArray[np.float64]


class GLMWithLOOCV:
    def __init__(
        self,
        loss: Callable[[Array, Array], Array],
        reg: Callable[[Array], float],
        inverse_link: Callable[[Array], Array],
    ) -> None:
        self.loss = loss
        self.reg = reg
        self.inverse_link = inverse_link

    def f(self, theta: Array, X: Array, y: Array) -> float:
        y_hat = X @ theta
        return self.loss(y_hat, y).sum() + self.reg(theta)

    def fit(self, X: Array, y: Array):
        # We optimize f with L-BFGS-B as it has reasonable performance with the data below,
        #  but any other convex optimization algorithm can be used here.
        result = scipy.optimize.minimize(
            jax.value_and_grad(lambda theta: self.f(theta, X, y)),
            x0=np.zeros(X.shape[1]),
            method="L-BFGS-B",
            jac=True,
        )
        self.theta_ = result.x
        return self

    def predict(self, X: Array) -> Array:
        return self.inverse_link(X @ self.theta_)

    def fit_loocv_predict(self, X: Array, y: Array) -> Array:
        self.fit(X, y)
        y_hat = X @ self.theta_
        # Per-sample first and second derivatives of the loss with respect to y_hat.
        l_prime = jax.vmap(jax.grad(self.loss, argnums=0))(y_hat, y)
        l_prime_prime = jax.vmap(jax.hessian(self.loss, argnums=0))(y_hat, y)
        # Hessian of f at theta, and t_j = H^{-1} x_j for all j (columns of t).
        H = jax.hessian(self.f, argnums=0)(self.theta_, X, y)
        t = scipy.linalg.solve(
            H,
            X.T,
            overwrite_a=True,
            assume_a="pos",
        )
        # h_j = x_j^T t_j, then apply y_tilde_j ~= y_hat_j + h_j l' / (1 - h_j l'').
        h = np.einsum("ij,ji->i", X, t)
        return self.inverse_link(y_hat + (h / (1 - h * l_prime_prime)) * l_prime)
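As a hypothetical usage sketch (the names X_counts and y_counts are placeholders, and this model is not used in the example that follows), a Poisson regression with a small ridge penalty could be set up like this:

import jax.numpy as jnp

poisson_model = GLMWithLOOCV(
    loss=lambda y_hat, y: jnp.exp(y_hat) - y * y_hat,  # Poisson negative log-likelihood (up to a constant)
    reg=lambda theta: 1e-2 * jnp.sum(theta**2),  # small ridge penalty
    inverse_link=jnp.exp,  # map the predicted log-rate back to a rate
)
# y_tilde = poisson_model.fit_loocv_predict(X_counts, y_counts)  # approximate LOO predictions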

Example

To illustrate the concepts discussed above, we train a classifier on a dataset for predicting heart disease events. From a quick glance at Kaggle, it appears that an AUC of approximately 0.9 is achievable.

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, FunctionTransformer
from sklearn import set_config
from sklearn.metrics import roc_auc_score

df = pd.read_csv("data/heart.csv")
df
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
0 40 M ATA 140 289 0 Normal 172 N 0.0 Up 0
1 49 F NAP 160 180 0 Normal 156 N 1.0 Flat 1
2 37 M ATA 130 283 0 ST 98 N 0.0 Up 0
3 48 F ASY 138 214 0 Normal 108 Y 1.5 Flat 1
4 54 M NAP 150 195 0 Normal 122 N 0.0 Up 0
... ... ... ... ... ... ... ... ... ... ... ... ...
913 45 M TA 110 264 0 Normal 132 N 1.2 Flat 1
914 68 M ASY 144 193 1 Normal 141 N 3.4 Flat 1
915 57 M ASY 130 131 0 Normal 115 Y 1.2 Flat 1
916 57 F ATA 130 236 0 LVH 174 N 0.0 Flat 1
917 38 M NAP 138 175 0 Normal 173 N 0.0 Up 0

918 rows × 12 columns

y = df["HeartDisease"]
X = df.drop(columns=["HeartDisease"])
X["one"] = 1.0  # an all-ones column to implicitly fit an intercept term
X = pd.get_dummies(X, drop_first=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
print(f"{x_train.shape=}, {x_test.shape=}")
x_train.shape=(642, 16), x_test.shape=(276, 16)

Our approach involves using stratified logistic regression with a combination of Laplacian and sum of squares regularization. While this may not be the optimal model for this specific problem, it serves well for demonstrating the concepts.

In our model, we stratify over the sex of the patient, meaning we fit two coefficient vectors: one for males and one for females. The Laplacian regularization promotes similarity between the coefficient vectors for females and males. You can read more about stratified models with Laplacian regularization here.

def stratify(X: pd.DataFrame) -> Array:
    z = X["Sex_M"].values[:, np.newaxis]
    X = X.drop(columns=["Sex_M"]).astype(float).values
    return np.hstack([X * z, X * ~z])


transformer = Pipeline(
    [
        (
            "scale",
            ColumnTransformer(
                [
                    ("none", "passthrough", ["one", "Sex_M"]),
                    (
                        "scale",
                        StandardScaler(),
                        list(set(x_train.columns) - {"one", "Sex_M"}),
                    ),
                ],
                verbose_feature_names_out=False,
            ).set_output(transform="pandas"),
        ),
        ("stratify", FunctionTransformer(stratify)),
    ]
)

x_train = transformer.fit_transform(x_train)
x_test = transformer.transform(x_test)

Next we define the regularization matrices:

m = x_train.shape[1]
laplacian = np.array([[1, -1], [-1, 1]])
laplacian = np.kron(laplacian, np.eye(m // 2))
ridge = np.eye(m)
ridge[0] = 0  # no penalty on the intercept terms (one per stratum)
ridge[m // 2] = 0

We have two hyperparameters in our model: the strength of the sum of squares (ridge) regularization and the strength of the Laplacian regularization.
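In terms of the matrices defined above, this corresponds to the regularization function \[ r(\theta') = \theta'^T \left( \alpha \, R_{\mathrm{ridge}} + \beta \, L \right) \theta', \] where \(R_{\mathrm{ridge}}\) and \(L\) are the ridge and laplacian matrices, and \(\alpha, \beta \geq 0\) are the two hyperparameters. This is exactly the reg function passed to the model in model_factory below.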

For hyperparameter optimization, we utilize Optuna.

import optuna


def model_factory(alpha: float, beta: float) -> GLMWithLOOCV:
    R = alpha * ridge + beta * laplacian
    return GLMWithLOOCV(
        loss=lambda y_hat, y: -jax.nn.log_sigmoid((y * 2 - 1) * y_hat),
        reg=lambda theta: theta.T @ R @ theta,
        inverse_link=jax.nn.sigmoid,
    )


def objective(trial: optuna.Trial):
    alpha = trial.suggest_float("alpha", 1e-6, 1e3, log=True)
    beta = trial.suggest_float("beta", 1e-6, 1e3, log=True)
    model = model_factory(alpha, beta)
    y_tilde = model.fit_loocv_predict(x_train, y_train.values)
    return -roc_auc_score(
        y_train, y_tilde
    )  # minus since optuna minimizes the objective and we need to maximize


study: optuna.Study = optuna.create_study()
study.optimize(objective, n_trials=50)
model = model_factory(**study.best_params)
roc_auc_score(
    y_test,
    model.fit(x_train, y_train.values).predict(x_test),
)
[I 2024-03-24 08:56:29,175] A new study created in memory with name: no-name-61410be4-4d1d-4a9b-8cc6-068004336d5c
[I 2024-03-24 08:56:30,700] Trial 0 finished with value: -0.9092301389105666 and parameters: {'alpha': 0.012775258780126265, 'beta': 0.002191596541067304}. Best is trial 0 with value: -0.9092301389105666.
[I 2024-03-24 08:56:31,058] Trial 1 finished with value: -0.909025284844701 and parameters: {'alpha': 0.0006298752554559054, 'beta': 2.1509611015314225e-05}. Best is trial 0 with value: -0.9092301389105666.
[I 2024-03-24 08:56:31,605] Trial 2 finished with value: -0.9118639769002653 and parameters: {'alpha': 4.938878297009436e-06, 'beta': 780.3122728275399}. Best is trial 2 with value: -0.9118639769002653.
[I 2024-03-24 08:56:31,926] Trial 3 finished with value: -0.914858748244108 and parameters: {'alpha': 3.1356031871802617, 'beta': 30.256480198543056}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:32,239] Trial 4 finished with value: -0.9104885281723115 and parameters: {'alpha': 0.005919567636794085, 'beta': 0.21420343169280254}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:32,557] Trial 5 finished with value: -0.9091325893553925 and parameters: {'alpha': 0.0029083107541165135, 'beta': 0.003468355333742003}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:32,947] Trial 6 finished with value: -0.9103422038395502 and parameters: {'alpha': 0.25663606336322276, 'beta': 2.027770483000284e-06}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:33,339] Trial 7 finished with value: -0.912537068830966 and parameters: {'alpha': 1.0538711667069038, 'beta': 0.0003360283526189701}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:33,622] Trial 8 finished with value: -0.9089862650226315 and parameters: {'alpha': 1.079098816654311e-05, 'beta': 1.3846427704434615e-06}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:34,060] Trial 9 finished with value: -0.9089667551115967 and parameters: {'alpha': 1.2646277638647537e-06, 'beta': 5.240465554229351e-05}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:34,378] Trial 10 finished with value: -0.8929003433744341 and parameters: {'alpha': 521.943351068861, 'beta': 21.128403850663364}. Best is trial 3 with value: -0.914858748244108.
[I 2024-03-24 08:56:34,601] Trial 11 finished with value: -0.9151416419541125 and parameters: {'alpha': 7.849754690834663, 'beta': 0.6886856542708458}. Best is trial 11 with value: -0.9151416419541125.
[I 2024-03-24 08:56:34,839] Trial 12 finished with value: -0.9127614328078664 and parameters: {'alpha': 36.17696304784756, 'beta': 0.9382223092187879}. Best is trial 11 with value: -0.9151416419541125.
[I 2024-03-24 08:56:35,023] Trial 13 finished with value: -0.9154538005306695 and parameters: {'alpha': 20.819919697832866, 'beta': 18.009933396875475}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:35,430] Trial 14 finished with value: -0.9134150148275324 and parameters: {'alpha': 50.2350037043773, 'beta': 3.819223887756615}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:35,728] Trial 15 finished with value: -0.8607480099890744 and parameters: {'alpha': 768.6568030874997, 'beta': 0.07013747752246598}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:36,291] Trial 16 finished with value: -0.9138930076478852 and parameters: {'alpha': 9.155957878379764, 'beta': 819.8096902730499}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:36,608] Trial 17 finished with value: -0.9130540814733884 and parameters: {'alpha': 0.1485485239581707, 'beta': 38.97815021766497}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:36,858] Trial 18 finished with value: -0.9053866864367098 and parameters: {'alpha': 78.72829138458891, 'beta': 0.016036059061954787}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:37,161] Trial 19 finished with value: -0.9136783986265022 and parameters: {'alpha': 0.1375243107258574, 'beta': 1.8291549312508841}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:37,528] Trial 20 finished with value: -0.9129760418292493 and parameters: {'alpha': 1.0707674807216208, 'beta': 117.83261301793227}. Best is trial 13 with value: -0.9154538005306695.
[I 2024-03-24 08:56:37,795] Trial 21 finished with value: -0.9159903230841268 and parameters: {'alpha': 4.416371155644641, 'beta': 9.316946205587612}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:38,065] Trial 22 finished with value: -0.9146831590447948 and parameters: {'alpha': 12.426070601603788, 'beta': 0.3853665693475585}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:38,316] Trial 23 finished with value: -0.8677228031840175 and parameters: {'alpha': 319.5534395847857, 'beta': 2.7762991166087905}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:38,620] Trial 24 finished with value: -0.9158342437958482 and parameters: {'alpha': 3.112413409324069, 'beta': 7.37150462979163}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:38,914] Trial 25 finished with value: -0.9150733572654908 and parameters: {'alpha': 0.7868633423237494, 'beta': 8.431547546648925}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:39,290] Trial 26 finished with value: -0.9122541751209613 and parameters: {'alpha': 0.03174734240904159, 'beta': 114.13914385182184}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:39,658] Trial 27 finished with value: -0.9116103480568128 and parameters: {'alpha': 90.62207805800448, 'beta': 202.6257677752164}. Best is trial 21 with value: -0.9159903230841268.
[I 2024-03-24 08:56:39,917] Trial 28 finished with value: -0.9160781176837834 and parameters: {'alpha': 6.553554396630455, 'beta': 11.167094954503991}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:40,310] Trial 29 finished with value: -0.9098349461526456 and parameters: {'alpha': 0.029887003146111108, 'beta': 0.05116561128441574}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:40,633] Trial 30 finished with value: -0.9140198220696114 and parameters: {'alpha': 2.942224930336697, 'beta': 0.00513193420243311}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:40,934] Trial 31 finished with value: -0.915219681598252 and parameters: {'alpha': 22.919512362299404, 'beta': 12.27623192594196}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:41,189] Trial 32 finished with value: -0.9157464491961917 and parameters: {'alpha': 2.757918964666127, 'beta': 6.038269195534732}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:41,617] Trial 33 finished with value: -0.9142149211799594 and parameters: {'alpha': 2.838682256536785, 'beta': 0.22519746478261113}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:41,949] Trial 34 finished with value: -0.9148392383330732 and parameters: {'alpha': 0.3616241392613928, 'beta': 5.830178783914518}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:42,382] Trial 35 finished with value: -0.9120005462775089 and parameters: {'alpha': 0.0008822564977771567, 'beta': 307.9388398609111}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:42,562] Trial 36 finished with value: -0.9134930544716716 and parameters: {'alpha': 0.04187871349145953, 'beta': 1.8734690287950477}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:42,925] Trial 37 finished with value: -0.9120590760106134 and parameters: {'alpha': 120.58206602832338, 'beta': 38.5688359496023}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:43,277] Trial 38 finished with value: -0.9141856563134072 and parameters: {'alpha': 2.86004817515422, 'beta': 0.14090599669599804}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:43,658] Trial 39 finished with value: -0.9131028562509754 and parameters: {'alpha': 0.46403986565570593, 'beta': 48.64240100580483}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:44,034] Trial 40 finished with value: -0.9092106289995319 and parameters: {'alpha': 3.056882608727012e-05, 'beta': 0.013794522533314474}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:44,311] Trial 41 finished with value: -0.9159513032620571 and parameters: {'alpha': 7.453204156267973, 'beta': 18.41659964949616}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:44,593] Trial 42 finished with value: -0.916068362728266 and parameters: {'alpha': 5.494711869607197, 'beta': 6.6964359014868124}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:44,981] Trial 43 finished with value: -0.9152391915092868 and parameters: {'alpha': 6.789849837478972, 'beta': 0.9088200937340541}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:45,353] Trial 44 finished with value: -0.9095422974871235 and parameters: {'alpha': 272.45143282464653, 'beta': 86.15445718504921}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:45,689] Trial 45 finished with value: -0.9129857967847667 and parameters: {'alpha': 1.245276289906343, 'beta': 494.7421191070618}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:46,090] Trial 46 finished with value: -0.9099032308412675 and parameters: {'alpha': 0.0911593221200086, 'beta': 0.0005329474451518324}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:46,388] Trial 47 finished with value: -0.9154050257530826 and parameters: {'alpha': 19.418957034796815, 'beta': 17.869675293948607}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:46,745] Trial 48 finished with value: -0.9114152489464649 and parameters: {'alpha': 0.004900207869786724, 'beta': 0.5092857463474738}. Best is trial 28 with value: -0.9160781176837834.
[I 2024-03-24 08:56:47,053] Trial 49 finished with value: -0.9153757608865303 and parameters: {'alpha': 5.6025164558499885, 'beta': 1.4684843094190325}. Best is trial 28 with value: -0.9160781176837834.
0.9398954703832751

Verification of implementation

To assess the accuracy of our implementation, we compare the approximated leave-one-out predictions with the actual leave-one-out predictions:

from time import time

import plotly.express as px
import plotly.io as pio

pio.renderers.default = "notebook"

from sklearn.model_selection import LeaveOneOut

t_start = time()
y_tilde = np.zeros(y_train.shape)
for i, (train_index, val_index) in enumerate(LeaveOneOut().split(x_train)):
    X_loo = x_train[train_index, :]
    y_loo = y_train.values[train_index]
    model.fit(X_loo, y_loo)
    y_tilde[i] = model.predict(x_train[val_index, :])[0]
standard_loocv_runtime = time() - t_start

t_start = time()
y_tilde_approx = model.fit_loocv_predict(x_train, y_train.values)
efficient_loocv_runtime = time() - t_start

print(f"{np.abs(y_tilde - y_tilde_approx).mean()=}")
px.scatter(x=y_tilde, y=y_tilde_approx)
np.abs(y_tilde - y_tilde_approx).mean()=4.465176e-05
px.histogram(y_tilde - y_tilde_approx)

The approximation demonstrates high accuracy in this instance. Now, let’s compare the runtimes:

print(f"{standard_loocv_runtime = :.1e}")
print(f"{efficient_loocv_runtime = :.1e}")
print(f"{standard_loocv_runtime/efficient_loocv_runtime = :.0f}.")
standard_loocv_runtime = 4.7e+01
efficient_loocv_runtime = 1.3e-01
standard_loocv_runtime/efficient_loocv_runtime = 357.

A significant speedup! However, it’s important to note that there is room for further optimization. For instance, the L-BFGS-B iterations could be warm-started from a previous solution (see the sketch below), or we could utilize JIT compilation as we did in part 1.
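A minimal sketch of the warm-start idea, written as a hypothetical extra method attached to GLMWithLOOCV (it is not part of the class as defined above):

def fit_warm_start(self, X: Array, y: Array):
    # Hypothetical variant of fit(): start L-BFGS-B from the previous coefficients, if any.
    x0 = getattr(self, "theta_", None)
    if x0 is None:
        x0 = np.zeros(X.shape[1])
    result = scipy.optimize.minimize(
        jax.value_and_grad(lambda theta: self.f(theta, X, y)),
        x0=x0,
        method="L-BFGS-B",
        jac=True,
    )
    self.theta_ = result.x
    return self


GLMWithLOOCV.fit_warm_start = fit_warm_start  # monkey-patch, for demonstration only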

References

Efficient LOOCV for ordinary least squares and ridge regression is mentioned in several well-known books, such as The Elements of Statistical Learning and An Introduction to Statistical Learning. I first encountered it in a brief mention in All of Statistics.

The only reference I am aware of that discusses the general quadratic case, and a similar approach for the non-quadratic approximation, is this thesis by Rosa Meijer.