[1]:
%matplotlib inline
2.5. Multinomial Logistic Regression
This example shows how to use skscope to fit a multinomial logistic regression model under a sparsity constraint.
2.5.1. Introduction
Multinomial logistic regression is a regression method for predicting the probabilities of the categories of a categorical outcome with more than two categories. It extends binary logistic regression, which predicts the probability of a binary outcome.
2.5.2. Mathematical Derivation
In multinomial logistic regression, we have multiple categories, denoted by \(k=1,2,...,K\). We want to predict the probability of each category given a set of predictor variables \(X\). We assume that the probability of each category is a function of the predictor variables, and that the probabilities for each category sum to 1.
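We can model the probability of each category using the softmax function:
\[
P(Y=k \mid X=x) = \frac{e^{\beta_{0k} + \beta_k^\top x}}{\sum_{l=1}^{K} e^{\beta_{0l} + \beta_l^\top x}},
\]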
where \(Y\) is the categorical outcome, \(X\) is the vector of predictor variables, \(\beta _{0k}\) and \(\beta _k\) are the intercept and coefficient vectors for category \(k\), and \(e\) is the base of the natural logarithm.
The softmax function ensures that the probabilities for all categories sum to 1: the numerator is the exponentiated linear score of category \(k\), and the denominator sums these exponentiated scores over all \(K\) categories, which normalizes them into probabilities.
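A short JAX sketch of the softmax model, using `jax.nn.softmax` and made-up coefficient values (two observations, three predictors, \(K=3\)). It checks that each row of probabilities sums to 1, and it also shows that adding the same vector to every \(\beta_k\) leaves the probabilities unchanged, so the coefficients are identified only up to such a shift.

```python
import jax.numpy as jnp
from jax.nn import softmax

# Hypothetical toy values: 2 observations, 3 predictors, K = 3 classes
x = jnp.array([[1.0, 0.5, -1.0],
               [0.0, 2.0, 1.0]])
beta0 = jnp.array([0.1, -0.2, 0.0])       # intercepts beta_{0k}, one per class
beta = jnp.array([[1.0, -1.0, 0.0],       # column k holds the coefficient vector beta_k
                  [0.5, 0.3, 0.0],
                  [-0.4, 0.8, 0.0]])

def class_probabilities(x, beta0, beta):
    """P(Y = k | X = x) for every row of x and every class k."""
    logits = beta0 + x @ beta             # linear scores, shape (n, K)
    return softmax(logits, axis=1)        # exponentiate and normalize each row

probs = class_probabilities(x, beta0, beta)
print(probs.sum(axis=1))                  # each row sums to 1 (up to rounding)

# Add the same vector c to every beta_k: each observation's scores all shift by x @ c,
# so the normalized probabilities do not change
c = jnp.array([2.0, -1.0, 0.5])
shifted = class_probabilities(x, beta0, beta + c[:, None])
print(jnp.allclose(probs, shifted, atol=1e-5))
```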
We can estimate the coefficients by maximum likelihood. The likelihood function for multinomial logistic regression is:
\[
L(\beta) = \prod_{i=1}^{n} \prod_{k=1}^{K} P(Y_i=k \mid X_i=x_i)^{I(Y_i=k)},
\]
where \(n\) is the number of observations, \(I(Y_i=k)\) is an indicator function that equals 1 if \(Y_i=k\) and 0 otherwise, and \(P(Y_i=k|X_i=x_i)\) is the predicted probability of category \(k\) for observation \(i\).
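To make the indicator-function form concrete, here is a sketch with three observations and hypothetical predicted probabilities. Raising each probability to the power \(I(Y_i=k)\) keeps only the probability of the observed class (every other factor becomes 1), so the likelihood is the product of those probabilities.

```python
import jax.numpy as jnp

# Hypothetical predicted probabilities P(Y_i = k | X_i = x_i): n = 3 observations, K = 3 classes
probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])
labels = jnp.array([0, 1, 2])             # observed classes Y_i (0-based)

# Indicator matrix I(Y_i = k), i.e. the one-hot encoding of the labels
indicator = (labels[:, None] == jnp.arange(3)[None, :]).astype(probs.dtype)

# L = prod_i prod_k P(Y_i = k | X_i)^{I(Y_i = k)}; factors with I = 0 equal 1
likelihood = jnp.prod(probs ** indicator)
print(likelihood)                         # 0.7 * 0.8 * 0.4, about 0.224
```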
The negative log-likelihood function is:
\[
-\log L(\beta) = -\sum_{i=1}^{n} \sum_{k=1}^{K} I(Y_i=k) \log P(Y_i=k \mid X_i=x_i).
\]
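We estimate the coefficients by minimizing this function. To obtain a sparse model, we add the constraint that at most \(s\) predictors enter the model, where predictor \(j\) counts as unused when all of its coefficients \(\beta_{j1},\dots,\beta_{jK}\) (the \(j\)-th entries of \(\beta_1,\dots,\beta_K\)) are zero. The SCOPE algorithm, implemented in skscope's ScopeSolver, searches for the coefficients that minimize the negative log-likelihood under this sparsity constraint.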
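A minimal sketch of the two quantities in this problem, on hypothetical toy inputs: the negative log-likelihood written with one-hot indicators, and the number of predictors in use, counted as rows of a \(p \times K\) coefficient matrix that are not entirely zero. The helper names are illustrative and not part of skscope.

```python
import jax.numpy as jnp

def negative_log_likelihood(probs, onehot):
    """-sum_i sum_k I(Y_i = k) log P(Y_i = k | X_i = x_i)."""
    return -jnp.sum(onehot * jnp.log(probs))

def group_support_size(beta):
    """Number of predictors j whose coefficient row beta[j, :] is not all zero."""
    return int(jnp.sum(jnp.any(beta != 0, axis=1)))

probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])
onehot = jnp.array([[1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])
nll = negative_log_likelihood(probs, onehot)
print(nll)                                # equals -(log 0.7 + log 0.8)

# Rows are predictors, columns are classes; row 1 is all zero, so that predictor is unused
beta = jnp.array([[1.0, -2.0, 0.0],
                  [0.0,  0.0, 0.0],
                  [0.0,  0.5, 0.0]])
print(group_support_size(beta))           # 2 predictors in use
```

Under the constraint above, a candidate \(\beta\) is feasible when `group_support_size(beta) <= s`.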
The following Python code solves the sparse multinomial logistic regression problem.
2.5.3. Import necessary packages
[2]:
import jax.numpy as jnp
import numpy as np
from skscope import ScopeSolver
from abess.datasets import make_multivariate_glm_data
2.5.4. Set a seed
[3]:
np.random.seed(3)
2.5.5. Generate the data
First, we run multinomial logistic regression on an artificial dataset. The function make_multivariate_glm_data from abess.datasets generates simulated data; setting family="multinomial" produces a multinomial response.
The model assumes that each response follows a multinomial distribution. The dataset contains 500 observations and 20 predictors, but only five predictors influence the three classes.
[4]:
n = 500 # sample size
p = 20 # all predictors
k = 5 # real predictors
m = 3 # number of classes
data = make_multivariate_glm_data(n=n, p=p, k=k, family="multinomial", M=m)
X = data.x
y = data.y
print('real variables\' index:\n', set(np.nonzero(data.coef_)[0]))
print('real variables:\n', data.coef_)
real variables' index:
{0, 3, 7, 10, 19}
real variables:
[[ 5.44916029 -0.94953634 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ -1.39241163 -12.96678673 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 3.24543565 4.02033588 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ -1.38210809 4.07755579 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 5.79719104 3.61096451 0. ]]
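The printed matrix has five nonzero rows and an all-zero last column, which suggests that the last class serves as a reference category. The NumPy sketch below builds data with the same structure by hand. It is not how abess implements make_multivariate_glm_data; the standard normal design, the uniform coefficient range and the seed are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, k, m = 500, 20, 5, 3

# Assumed design: i.i.d. standard normal predictors
X = rng.standard_normal((n, p))

# k active predictors; the last class is the reference, so its column stays 0
coef = np.zeros((p, m))
active = rng.choice(p, size=k, replace=False)
coef[active, : m - 1] = rng.uniform(-5, 5, size=(k, m - 1))   # assumed coefficient range

# Class probabilities via the softmax of the linear predictors
logits = X @ coef
logits -= logits.max(axis=1, keepdims=True)    # subtract row max so exp cannot overflow
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Draw one class per observation and one-hot encode it
labels = np.array([rng.choice(m, p=row) for row in probs])
y = np.eye(m)[labels]
```

The one-hot layout of `y` is the form that the loss function in the next step expects.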
2.5.6. Define multinomial regression loss
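Next, to carry out sparsity-constrained optimization for multinomial logistic regression, we define the loss function multinomial_regression_loss according to [1], so that it matches the data generating function make_multivariate_glm_data. It is the negative log-likelihood derived above, averaged over the \(n\) observations (dividing by \(n\) does not change the minimizer). The code omits the intercepts \(\beta_{0k}\) and assumes that y is one-hot encoded, i.e. an \(n \times m\) matrix.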
[5]:
def multinomial_regression_loss(params):
    beta = params.reshape((p, m))
    # Compute the logits
    logits = jnp.dot(X, beta)
    # Compute the softmax probabilities
    softmax_probs = jnp.exp(logits) / jnp.sum(jnp.exp(logits), axis=1, keepdims=True)
    # Compute the NLL loss
    loss = -jnp.mean(jnp.sum(y * jnp.log(softmax_probs), axis=1))
    return loss
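The loss above exponentiates the logits directly, which can overflow in float32 when the logits are large. A numerically safer variant, sketched here as an assumption rather than the notebook's own code, uses `jax.nn.log_softmax` and builds the loss as a closure over X and y instead of reading global variables. It computes the same averaged negative log-likelihood; the helper name `make_multinomial_loss` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def make_multinomial_loss(X, y, p, m):
    """Build an averaged multinomial NLL over flattened (p * m) parameters."""
    X = jnp.asarray(X)
    y = jnp.asarray(y)                    # one-hot labels, shape (n, m)

    def loss(params):
        beta = params.reshape((p, m))
        logits = X @ beta
        # log_softmax subtracts the row-wise maximum internally, so exp never overflows
        log_probs = jax.nn.log_softmax(logits, axis=1)
        return -jnp.mean(jnp.sum(y * log_probs, axis=1))

    return loss

# Compare with the softmax-then-log formula on small random data
rng = np.random.default_rng(1)
X_demo = rng.standard_normal((50, 4))
y_demo = np.eye(3)[rng.integers(0, 3, size=50)]
params_demo = jnp.asarray(rng.standard_normal(4 * 3))

stable = make_multinomial_loss(X_demo, y_demo, 4, 3)(params_demo)

logits = jnp.asarray(X_demo) @ params_demo.reshape((4, 3))
probs = jnp.exp(logits) / jnp.sum(jnp.exp(logits), axis=1, keepdims=True)
naive = -jnp.mean(jnp.sum(jnp.asarray(y_demo) * jnp.log(probs), axis=1))
print(stable, naive)                      # agree up to float32 rounding

# With very large logits the stable version still returns a finite value
big_loss = make_multinomial_loss(1000 * X_demo, y_demo, 4, 3)(params_demo)
print(jnp.isfinite(big_loss))
```

In this notebook, `make_multinomial_loss(X, y, p, m)` could be passed to `solver.solve` in place of multinomial_regression_loss; we have not checked that it reproduces the exact output shown below.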
2.5.7. Use skscope to solve the sparse multinomial logistic regression problem
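After defining the data and the loss function, we call ScopeSolver to solve the sparsity-constrained optimization problem. Its first argument is the total number of parameters \(p \times m\) and its second is the sparsity level k; the group argument puts the m coefficients of each predictor into the same group, so that k limits the number of selected predictors rather than the number of nonzero coefficients.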
[6]:
solver = ScopeSolver(p * m, k, group=[i for i in range(p) for j in range(m)])  # one group per predictor
params = solver.solve(multinomial_regression_loss, jit=True)
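The group argument assigns each of the p * m entries of the flat parameter vector to a predictor. A small NumPy sketch, with p = 4 and m = 3 chosen only for illustration, shows that this list lines up with `reshape((p, m))`. That alignment is why the loss can reshape params and treat row i as the coefficients of predictor i.

```python
import numpy as np

p, m = 4, 3   # small sizes for illustration

# Same construction as in the solver call: the m coefficients of predictor i share group i
group = [i for i in range(p) for j in range(m)]
print(group)   # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]

# Reshaping a flat vector row-major to (p, m) puts exactly the entries of group i in row i
params = np.arange(p * m)
beta = params.reshape((p, m))
for i in range(p):
    assert list(np.flatnonzero(np.array(group) == i)) == list(beta[i])
```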
Now solver.params contains the coefficients of the multinomial logistic model with at most 5 predictors selected. Predictors whose coefficients are all zero are not used by the model:
[7]:
print(solver.params.reshape((p, m)))
[[ 4.39822362 -2.50996204 -1.88825765]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 3.99429591 -9.68531333 5.69101694]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0.93656944 2.00800227 -2.94457135]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[-2.31647311 3.43094489 -1.11447062]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 2.54118658 0.82594677 -3.36713346]]
We can further compare the coefficients estimated by skscope with the real coefficients in two ways: (1) the true support set versus the estimated support set, and (2) the true nonzero parameters versus the estimated nonzero parameters.
[8]:
print('real variables\' index:\n', set(np.nonzero(data.coef_)[0]))
print('Estimated variables\' index:\n', set(np.nonzero(solver.params.reshape((p, m)))[0]))
real variables' index:
{0, 3, 7, 10, 19}
Estimated variables' index:
{0, 3, 7, 10, 19}
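The support comparison above is done by eye. A small helper can report the true positives, false positives and false negatives of the selected predictors; it is a hypothetical utility, shown here on toy matrices rather than on the notebook's data.

```python
import numpy as np

def row_support(coef):
    """Indices of predictors (rows) with at least one nonzero coefficient."""
    return set(np.flatnonzero(np.any(coef != 0, axis=1)).tolist())

def support_recovery(true_coef, est_coef):
    """Return (true positives, false positives, false negatives) as sets of row indices."""
    true_s, est_s = row_support(true_coef), row_support(est_coef)
    return true_s & est_s, est_s - true_s, true_s - est_s

# Hypothetical toy coefficient matrices: rows are predictors, columns are classes
true_coef = np.array([[1.0, 0.0], [0.0, 0.0], [0.5, -0.2]])
est_coef = np.array([[0.8, 0.1], [0.3, 0.0], [0.0, 0.0]])
tp, fp, fn = support_recovery(true_coef, est_coef)
print(tp, fp, fn)   # {0} {1} {2}
```

In this notebook it could be called as `support_recovery(data.coef_, solver.params.reshape((p, m)))`; given the two identical sets printed above, the false-positive and false-negative sets would both be empty.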
[9]:
print("True parameter:\n", data.coef_)
print("Estimated parameter:\n", solver.params.reshape((p, m)))
True parameter:
[[ 5.44916029 -0.94953634 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ -1.39241163 -12.96678673 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 3.24543565 4.02033588 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ -1.38210809 4.07755579 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 5.79719104 3.61096451 0. ]]
Estimated parameter:
[[ 4.39822362 -2.50996204 -1.88825765]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 3.99429591 -9.68531333 5.69101694]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0.93656944 2.00800227 -2.94457135]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[-2.31647311 3.43094489 -1.11447062]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 2.54118658 0.82594677 -3.36713346]]
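The estimated third column is nonzero while the true one is zero. This is expected: the softmax probabilities do not change when the same number is added to all class coefficients of a predictor, so the coefficients are identified only up to such shifts. The true coefficients fix the last class at zero, while the estimate is not constrained that way. Subtracting the last column from each row puts the estimate on the same reference-class scale; a minimal sketch on two rows copied from the output above:

```python
import numpy as np

# Rows 0 and 3 of the true and estimated coefficient matrices, copied from the printed output
true_rows = np.array([[5.44916029, -0.94953634, 0.0],
                      [-1.39241163, -12.96678673, 0.0]])
est_rows = np.array([[4.39822362, -2.50996204, -1.88825765],
                     [3.99429591, -9.68531333, 5.69101694]])

# Subtract each row's last-class coefficient so that the last class becomes the reference
est_ref = est_rows - est_rows[:, [-1]]
print(est_ref)
print(true_rows)
```

Applied to all five selected rows of `solver.params.reshape((p, m))`, this normalization gives coefficients with the same signs as the true ones and comparable magnitudes.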