
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/example_group_lasso.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_example_group_lasso.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_example_group_lasso.py:


GroupLasso for linear regression
================================

A sample script demonstrating group lasso regression on synthetic data.

.. GENERATED FROM PYTHON SOURCE LINES 9-11

Setup
-----

.. GENERATED FROM PYTHON SOURCE LINES 11-22

.. code-block:: default


    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import r2_score

    from group_lasso import GroupLasso

    # Ask GroupLasso to keep a loss history (presumably what populates
    # ``gl.losses_``, which is plotted at the end of the example).
    GroupLasso.LOG_LOSSES = True
    # Seed NumPy's global generator so every run draws the same data.
    np.random.seed(0)









.. GENERATED FROM PYTHON SOURCE LINES 23-25

Set dataset parameters
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 25-35

.. code-block:: default

    # Draw ``num_groups`` groups whose sizes lie in [10, 20), and switch each
    # group on (1) or off (0) with equal probability.
    num_groups = 50
    group_sizes = [np.random.randint(10, 20) for _ in range(num_groups)]
    active_groups = [np.random.randint(2) for _ in group_sizes]
    # Column vector giving, for every coefficient, the index of its group.
    groups = np.concatenate(
        [size * [i] for i, size in enumerate(group_sizes)]
    ).reshape(-1, 1)
    num_coeffs = sum(group_sizes)
    num_datapoints = 10000
    noise_std = 20









.. GENERATED FROM PYTHON SOURCE LINES 36-38

Generate data matrix
--------------------

.. GENERATED FROM PYTHON SOURCE LINES 38-41

.. code-block:: default

    # Design matrix of i.i.d. standard-normal entries:
    # one row per datapoint, one column per coefficient.
    X = np.random.standard_normal(size=(num_datapoints, num_coeffs))









.. GENERATED FROM PYTHON SOURCE LINES 42-44

Generate coefficients
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 44-55

.. code-block:: default

    # One block of standard-normal draws per group; inactive groups are
    # multiplied by zero. A block is drawn for every group, active or not,
    # so the number of random draws does not depend on which groups are active.
    coefficient_blocks = []
    for group_size, is_active in zip(group_sizes, active_groups):
        coefficient_blocks.append(
            np.random.standard_normal(group_size) * is_active
        )
    w = np.concatenate(coefficient_blocks).reshape(-1, 1)
    true_coefficient_mask = w != 0
    intercept = 2









.. GENERATED FROM PYTHON SOURCE LINES 56-58

Generate regression targets
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 58-62

.. code-block:: default

    # Noise-free linear response, then additive Gaussian noise scaled by
    # ``noise_std``.
    y_true = X @ w + intercept
    noise = noise_std * np.random.randn(*y_true.shape)
    y = y_true + noise









.. GENERATED FROM PYTHON SOURCE LINES 63-65

View noisy data and compute maximum R^2
---------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 65-74

.. code-block:: default

    # Scatter the noise-free targets against the noisy ones.
    plt.figure()
    plt.plot(y, y_true, ".")
    plt.ylabel("Noise-free targets")
    plt.xlabel("Noisy targets")
    # The noisy targets act as ground truth, since they are all we would
    # observe in practice. Scoring the noise-free targets against them gives
    # the reference "best possible" R^2 reported after fitting.
    R2_best = r2_score(y, y_true)





.. image:: /auto_examples/images/sphx_glr_example_group_lasso_001.png
    :alt: example group lasso
    :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 75-77

Generate estimator and train it
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 77-91

.. code-block:: default

    # Estimator hyperparameters, gathered in one mapping so they are easy to
    # tweak. Note the library spells the keyword ``supress_warning``.
    estimator_params = dict(
        groups=groups,
        group_reg=5,
        l1_reg=0,
        frobenius_lipschitz=True,
        scale_reg="inverse_group_size",
        subsampling_scheme=1,
        supress_warning=True,
        n_iter=1000,
        tol=1e-3,
    )
    gl = GroupLasso(**estimator_params)
    gl.fit(X, y)






.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    GroupLasso(frobenius_lipschitz=True, group_reg=5,
               groups=array([[ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 0],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 1],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 2],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 3],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 4],
           [ 5],
           [ 5],
           [ 5],
           [ 5],...
           [46],
           [46],
           [46],
           [46],
           [46],
           [46],
           [46],
           [46],
           [46],
           [46],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [47],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [48],
           [49],
           [49],
           [49],
           [49],
           [49],
           [49],
           [49],
           [49],
           [49],
           [49]]),
               l1_reg=0, n_iter=1000, scale_reg='inverse_group_size',
               subsampling_scheme=1, supress_warning=True, tol=0.001)



.. GENERATED FROM PYTHON SOURCE LINES 92-94

Extract results and compute performance metrics
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 94-109

.. code-block:: default


    # Predictions and fitted quantities from the trained estimator.
    yhat = gl.predict(X)
    w_hat = gl.coef_
    sparsity_mask = gl.sparsity_mask_

    # In-sample goodness of fit.
    R2 = r2_score(y, yhat)

    # Summarise model size and fit quality.
    num_variables = len(sparsity_mask)
    num_chosen = sparsity_mask.sum()
    print(f"Number variables: {num_variables}")
    print(f"Number of chosen variables: {num_chosen}")
    print(f"R^2: {R2}, best possible R^2 = {R2_best}")






.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Number variables: 720
    Number of chosen variables: 313
    R^2: 0.29097931452380443, best possible R^2 = 0.46262785225190173




.. GENERATED FROM PYTHON SOURCE LINES 110-112

Visualise regression coefficients
---------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 112-132

.. code-block:: default

    # Overlay true and estimated coefficients against their position.
    plt.figure()
    plt.plot(w, ".", label="True weights")
    plt.plot(w_hat, ".", label="Estimated weights")
    # The labels above are only displayed once a legend is drawn.
    plt.legend()

    # Scatter learned against true coefficients. The gray guide is the
    # identity line y = x over the joint range of both arrays, so a point on
    # it is a perfectly recovered coefficient. (A line from the min-corner to
    # the max-corner of the data's bounding box is not y = x whenever the two
    # ranges differ.)
    coef_limits = [min(w.min(), w_hat.min()), max(w.max(), w_hat.max())]
    plt.figure()
    plt.plot(coef_limits, coef_limits, "gray")
    plt.scatter(w, w_hat, s=10)
    plt.ylabel("Learned coefficients")
    plt.xlabel("True coefficients")

    # Loss history stored on the fitted estimator.
    plt.figure()
    plt.plot(gl.losses_)
    plt.title("Loss plot")
    plt.ylabel("Mean squared error")
    plt.xlabel("Iteration")

    # Report the data shape and the true vs. estimated intercept.
    print(f"X shape: {X.shape}")
    print(f"True intercept: {intercept}")
    print(f"Estimated intercept: {gl.intercept_}")
    plt.show()



.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/images/sphx_glr_example_group_lasso_002.png
          :alt: example group lasso
          :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/images/sphx_glr_example_group_lasso_003.png
          :alt: example group lasso
          :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/images/sphx_glr_example_group_lasso_004.png
          :alt: Loss plot
          :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    X shape: (10000, 720)
    True intercept: 2
    Estimated intercept: [2.08271211]





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  22.963 seconds)


.. _sphx_glr_download_auto_examples_example_group_lasso.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: example_group_lasso.py <example_group_lasso.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: example_group_lasso.ipynb <example_group_lasso.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
