Comparison with cr-sparse

In this notebook, we compare the performance of the implementations of common sparsity-constrained optimization algorithms in skscope and cr-sparse:

  • IHT: Iterative Hard Thresholding

  • OMP: Orthogonal Matching Pursuit

  • HTP: Hard Thresholding Pursuit

  • GraSP / CoSaMP: Gradient Support Pursuit (skscope) and Compressive Sampling Matching Pursuit (cr-sparse)
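IHT and HTP share one core step: take a gradient step on the least-squares loss, then keep only the s largest-magnitude entries (hard thresholding). The NumPy sketch below illustrates that idea for the least-squares case only. The names `top_s_support` and `iht_sketch` are hypothetical and are not the skscope or cr-sparse APIs; the fixed step size 1/‖X‖₂², the fixed iteration count, and the lack of a stopping rule are simplifying assumptions.

```python
import numpy as np


def top_s_support(v, s):
    """Hypothetical helper: indices of the s entries of v with largest magnitude."""
    return np.argsort(np.abs(v))[-s:]


def iht_sketch(X, y, s, n_iter=200, debias=False):
    """Illustrative IHT / HTP for min ||y - X beta||^2 subject to ||beta||_0 <= s.

    Not the skscope or cr-sparse implementation.
    debias=False: IHT (gradient step, then keep the s largest entries).
    debias=True:  HTP (same support selection, then least squares on that support).
    """
    _, p = X.shape
    # Assumed step size: 1 / (largest singular value of X)^2 keeps the step stable
    step = 1.0 / np.linalg.norm(X, 2) ** 2
    beta = np.zeros(p)
    for _ in range(n_iter):
        # Gradient step on 0.5 * ||y - X beta||^2
        z = beta - step * (X.T @ (X @ beta - y))
        # Hard thresholding: select the s largest-magnitude entries
        support = top_s_support(z, s)
        beta = np.zeros(p)
        if debias:
            # HTP: refit least squares restricted to the selected columns
            beta[support] = np.linalg.lstsq(X[:, support], y, rcond=None)[0]
        else:
            # IHT: keep the thresholded gradient-step values as they are
            beta[support] = z[support]
    return beta
```

With `debias=True`, each iteration refits the coefficients on the selected support, which is what distinguishes HTP from plain IHT. OMP and CoSaMP/GraSP differ mainly in how they grow or prune the support, which this sketch does not cover.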

[2]:
import time

import numpy as np
import pandas as pd
import jax.numpy as jnp
from skscope.solver import OMPSolver, IHTSolver, HTPSolver, GraspSolver
from cr.sparse.pursuit import iht, omp, htp, cosamp
from abess.datasets import make_glm_data

The following function generates synthetic data and solves the sparsity-constrained least-squares problem.
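Written out, with X the n × p design matrix, y the response vector, s the target sparsity, and ‖β‖₀ the number of nonzero entries of β, the problem is shown below. The `objective` passed to skscope is exactly this mean squared error; cr-sparse's `matrix_solve` receives X, y and s directly, and its pursuit algorithms aim at the same s-sparse least-squares fit (the 1/n factor does not change the minimizer).

```latex
\hat{\beta} \in \arg\min_{\beta \in \mathbb{R}^{p}} \; \frac{1}{n} \lVert y - X\beta \rVert_2^2
\quad \text{subject to} \quad \lVert \beta \rVert_0 \le s
```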

It runs the skscope and cr-sparse implementations of each algorithm on the same data and reports their support-recovery accuracy (the fraction of the s true nonzero indices that the solver selects) and computation time.
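For reference, the accuracy metric can be isolated as a small helper computing |Ŝ ∩ S*| / s, where Ŝ is the estimated support and S* the true one. `support_recovery_accuracy` is a hypothetical name used only for illustration. It takes an estimated coefficient vector, as in the skscope branch of `test`; the cr-sparse branch computes the same quantity from the returned index set `solution.I`. Because each solver selects at most s indices, a score of 1.0 means the true support was recovered exactly.

```python
import numpy as np


def support_recovery_accuracy(estimated_coef, true_support, s):
    """Hypothetical helper: fraction of the s true nonzero indices that are also
    nonzero in estimated_coef, i.e. |estimated support ∩ true support| / s."""
    estimated_support = set(np.flatnonzero(estimated_coef).tolist())
    return len(estimated_support & set(np.asarray(true_support).tolist())) / s
```

For example, an estimate that is nonzero at indices 1, 4 and 9 when the true support is {1, 4, 7} scores 2/3.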

[3]:
def test(n=500, p=1000, s=5, random_state=None):
    print('=' * 20 + f'  n={n}, p={p}, s={s}  ' + '=' * 20)
    # Draw a random s-sparse coefficient vector with nonzero entries in {±1, ±2, ±3}
    rng = np.random.default_rng(random_state)
    true_support_set = rng.choice(np.arange(p), size=s, replace=False)
    real_coef = np.zeros(p)
    real_coef[true_support_set] = rng.choice(np.arange(1, 4), size=s) * rng.choice([1, -1], size=s)
    # Simulate a Gaussian linear model with these coefficients
    data = make_glm_data(n=n, p=p, k=s, family='gaussian', coef_=real_coef)
    X, y = data.x, data.y

    # Results table indexed by (algorithm, package)
    iterables = [['OMP', 'IHT', 'HTP', 'Grasp'], ['cr-sparse', 'skscope']]
    index = pd.MultiIndex.from_product(iterables, names=['Algorithm', 'Package'])
    res = pd.DataFrame(columns=['Accuracy', 'Time'], index=index)

    # Mean squared error of the linear model; skscope minimizes it subject to ||params||_0 <= s
    def objective(params):
        loss = jnp.mean((y - X @ params) ** 2)
        return loss

    # Pair each skscope solver with the matching cr-sparse pursuit module
    for algo in iterables[0]:
        if algo == 'OMP':
            solver = OMPSolver(p, sparsity=s)
            model = omp
        elif algo == 'IHT':
            solver = IHTSolver(p, sparsity=s)
            model = iht
        elif algo == 'HTP':
            solver = HTPSolver(p, sparsity=s)
            model = htp
        elif algo == 'Grasp':
            solver = GraspSolver(p, sparsity=s)
            model = cosamp

        # cr-sparse: solution.I holds the indices of the selected support
        t_begin = time.time()
        solution = model.matrix_solve(jnp.array(X), y, s)
        t_cr = time.time() - t_begin
        acc_cr = len(set(solution.I.tolist()) & set(true_support_set)) / s
        res.loc[(algo, 'cr-sparse')] = [acc_cr, np.round(t_cr, 4)]

        # skscope: solve() returns the full p-dimensional coefficient vector
        t_begin = time.time()
        params = solver.solve(objective, jit=True)
        t_skscope = time.time() - t_begin
        acc_skscope = len(set(np.nonzero(params)[0]) & set(true_support_set)) / s
        res.loc[(algo, 'skscope')] = [acc_skscope, np.round(t_skscope, 4)]

    print(res)

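The results are shown in the three tables below, one per data setting (n, p, s). Time is wall-clock seconds for a single solve, and because `test` is called without a fixed `random_state`, the exact numbers vary from run to run.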

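In these runs, skscope recovers the true support at least as accurately as cr-sparse for every algorithm and setting (both reach 1.0 with Grasp/CoSaMP). skscope is also faster in every case except OMP at n=5000, p=10000, where cr-sparse finishes sooner but recovers only 10% of the true support; the time gap is largest for HTP and Grasp/CoSaMP at the two larger settings.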

[4]:
settings = [
    (500, 1000, 5),
    (2000, 5000, 10),
    (5000, 10000, 10),
]
for setting in settings:
    n, p, s = setting
    test(n=n, p=p, s=s)
====================  n=500, p=1000, s=5  ====================
                    Accuracy    Time
Algorithm Package
OMP       cr-sparse      0.2  1.9142
          skscope        1.0   0.198
IHT       cr-sparse      0.0  0.4089
          skscope        0.8  0.1685
HTP       cr-sparse      0.4  0.5739
          skscope        0.8   0.166
Grasp     cr-sparse      1.0  0.8989
          skscope        1.0  0.1799
====================  n=2000, p=5000, s=10  ====================
                    Accuracy     Time
Algorithm Package
OMP       cr-sparse      0.1    2.647
          skscope        1.0   1.8372
IHT       cr-sparse      0.6   3.4628
          skscope        1.0   1.2507
HTP       cr-sparse      0.6  42.7856
          skscope        1.0    1.257
Grasp     cr-sparse      1.0  43.2662
          skscope        1.0   1.5364
====================  n=5000, p=10000, s=10  ====================
                    Accuracy      Time
Algorithm Package
OMP       cr-sparse      0.1    2.5915
          skscope        1.0    8.4356
IHT       cr-sparse      0.7    8.3954
          skscope        1.0    6.1218
HTP       cr-sparse      0.3  590.8998
          skscope        1.0    6.1951
Grasp     cr-sparse      1.0  603.0331
          skscope        1.0    6.5937
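To put the timing comparison in numbers, the sketch below recomputes the ratio of cr-sparse time to skscope time from the values printed in the tables (a ratio above 1 means skscope was faster). `reported_times` and `speedups` are illustrative names, and since the timings come from a single unseeded run, the ratios are indicative rather than definitive.

```python
# Wall-clock times in seconds, copied from the tables above:
# (cr-sparse time, skscope time) for each algorithm and (n, p, s) setting.
reported_times = {
    (500, 1000, 5): {"OMP": (1.9142, 0.198), "IHT": (0.4089, 0.1685),
                     "HTP": (0.5739, 0.166), "Grasp": (0.8989, 0.1799)},
    (2000, 5000, 10): {"OMP": (2.647, 1.8372), "IHT": (3.4628, 1.2507),
                       "HTP": (42.7856, 1.257), "Grasp": (43.2662, 1.5364)},
    (5000, 10000, 10): {"OMP": (2.5915, 8.4356), "IHT": (8.3954, 6.1218),
                        "HTP": (590.8998, 6.1951), "Grasp": (603.0331, 6.5937)},
}


def speedups(times):
    """Ratio of cr-sparse time to skscope time for every (setting, algorithm)."""
    return {
        setting: {algo: t_cr / t_sk for algo, (t_cr, t_sk) in row.items()}
        for setting, row in times.items()
    }


for setting, row in speedups(reported_times).items():
    print(setting, {algo: round(ratio, 1) for algo, ratio in row.items()})
```

At the largest setting this gives a ratio of about 95 for HTP and about 0.3 for OMP, the one case where cr-sparse was faster.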