
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/example_warm_start.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_example_warm_start.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_example_warm_start.py:


Warm start to choose regularisation strength
============================================

.. GENERATED FROM PYTHON SOURCE LINES 7-9

Setup
-----

.. GENERATED FROM PYTHON SOURCE LINES 9-19

.. code-block:: default


    import matplotlib.pyplot as plt
    import numpy as np

    from group_lasso import GroupLasso

    np.random.seed(0)
    GroupLasso.LOG_LOSSES = True









.. GENERATED FROM PYTHON SOURCE LINES 20-22

Set dataset parameters
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 22-32

.. code-block:: default

    group_sizes = [np.random.randint(10, 20) for i in range(50)]
    active_groups = [np.random.randint(2) for _ in group_sizes]
    groups = np.concatenate(
        [size * [i] for i, size in enumerate(group_sizes)]
    ).reshape(-1, 1)
    num_coeffs = sum(group_sizes)
    num_datapoints = 10000
    noise_std = 20

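Each entry of ``groups`` is the group label of the corresponding column of the data matrix, so ``groups`` has one row per coefficient. The sketch below uses three small, hand-picked group sizes (illustrative values, not the ones drawn above) to show what the list-comprehension construction produces and that ``np.repeat`` builds the same array:

.. code-block:: python

    import numpy as np

    group_sizes = [3, 1, 2]  # illustrative sizes, not the example's random ones

    # Construction used in the example: repeat each group index `size` times.
    groups = np.concatenate(
        [size * [i] for i, size in enumerate(group_sizes)]
    ).reshape(-1, 1)

    # Equivalent vectorised construction.
    groups_repeat = np.repeat(np.arange(len(group_sizes)), group_sizes).reshape(-1, 1)

    print(groups.ravel().tolist())  # [0, 0, 0, 1, 2, 2]

Either form yields a column vector whose length equals ``sum(group_sizes)``, which is the shape the example passes to ``GroupLasso(groups=...)``.
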








.. GENERATED FROM PYTHON SOURCE LINES 33-35

Generate data matrix
--------------------

.. GENERATED FROM PYTHON SOURCE LINES 35-38

.. code-block:: default

    X = np.random.standard_normal((num_datapoints, num_coeffs))









.. GENERATED FROM PYTHON SOURCE LINES 39-41

Generate coefficients
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 41-52

.. code-block:: default

    w = np.concatenate(
        [
            np.random.standard_normal(group_size) * is_active
            for group_size, is_active in zip(group_sizes, active_groups)
        ]
    )
    w = w.reshape(-1, 1)
    true_coefficient_mask = w != 0
    intercept = 2

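Each group's coefficients are either all drawn from a standard normal distribution or all set to zero, so the truly active groups are the ones whose block of ``w`` has a non-zero Euclidean norm. Per-group norms are also what a group lasso penalty shrinks, which makes them a natural way to inspect ``w``. A minimal sketch with a hand-written coefficient vector (illustrative values, not the example's data):

.. code-block:: python

    import numpy as np

    # Illustrative: 3 groups of sizes 3, 1 and 2; group 1 is inactive (all zeros).
    groups = np.array([0, 0, 0, 1, 2, 2])
    w = np.array([0.5, -1.0, 2.0, 0.0, 3.0, 4.0])

    # Euclidean norm of the coefficients in each group.
    group_norms = np.array([np.linalg.norm(w[groups == g]) for g in np.unique(groups)])

    # A group is active when its coefficient block is non-zero.
    active = group_norms > 0
    print(active.tolist())  # [True, False, True]

In the example itself, ``groups`` and ``w`` are column vectors, so they would need to be flattened with ``.ravel()`` before applying the same computation.
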








.. GENERATED FROM PYTHON SOURCE LINES 53-55

Generate regression targets
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 55-59

.. code-block:: default

    y_true = X @ w + intercept
    y = y_true + np.random.randn(*y_true.shape) * noise_std

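Because the columns of ``X`` are independent standard normal variables, the noise-free signal ``X @ w + intercept`` has variance equal to the squared Euclidean norm of ``w`` (the constant intercept does not change it). The signal-to-noise ratio of the targets is therefore roughly ``||w||² / noise_std²``. This sketch checks the identity numerically on a smaller problem; the sizes and the seed are illustrative assumptions, not the example's values:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)  # illustrative seed
    num_datapoints, num_coeffs, noise_std = 20000, 50, 20

    X = rng.standard_normal((num_datapoints, num_coeffs))
    w = rng.standard_normal((num_coeffs, 1))
    y_true = X @ w + 2  # a constant intercept does not change the variance

    theoretical_var = float(np.sum(w**2))  # ||w||^2
    empirical_var = float(np.var(y_true))
    snr = theoretical_var / noise_std**2
    print(f"signal variance {empirical_var:.1f} (theory {theoretical_var:.1f}), SNR {snr:.3f}")

A low ratio means the noise dominates each individual target, which is one reason the example uses many data points.
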








.. GENERATED FROM PYTHON SOURCE LINES 60-62

Generate estimator and train it
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 62-84

.. code-block:: default

    num_regs = 10
    regularisations = np.logspace(-0.5, 1.5, num_regs)
    weights = np.empty((num_regs, w.shape[0]))
    gl = GroupLasso(
        groups=groups,
        group_reg=5,
        l1_reg=0,
        frobenius_lipschitz=True,
        scale_reg="inverse_group_size",
        subsampling_scheme=1,
        supress_warning=True,
        n_iter=1000,
        tol=1e-3,
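        warm_start=True,  # Initialise each fit with the coefficients from the previous fit
    )

    # Fit from the strongest to the weakest regularisation, reusing the previous
    # solution each time. Row j of weights holds the sparsity mask obtained with
    # regularisations[j].
    for i, group_reg in enumerate(regularisations[::-1]):
        gl.group_reg = group_reg
        gl.fit(X, y)
        weights[-(i + 1)] = gl.sparsity_mask_.squeeze()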

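Warm starting helps here because the fits run from the strongest to the weakest regularisation: consecutive solutions along this path are similar, so starting each fit from the previous coefficients leaves the solver close to the new optimum. The sketch below illustrates the idea with a plain proximal-gradient solver written in NumPy. It does not reproduce ``GroupLasso``'s internals: the solver, the square-root-of-group-size penalty weights, the missing intercept, the synthetic data, and the hypothetical helper names ``group_soft_threshold`` and ``fit_group_lasso`` are all assumptions made for illustration.

.. code-block:: python

    import numpy as np


    def group_soft_threshold(w, groups, threshold):
        """Hypothetical helper: proximal operator of
        threshold * sum_g sqrt(p_g) * ||w_g||_2, applied group by group."""
        out = np.zeros_like(w)
        for g in np.unique(groups):
            idx = groups == g
            norm = np.linalg.norm(w[idx])
            shrink = threshold * np.sqrt(idx.sum())  # assumed sqrt(group size) weight
            if norm > shrink:  # otherwise the whole group is set to zero
                out[idx] = (1 - shrink / norm) * w[idx]
        return out


    def fit_group_lasso(X, y, groups, reg, w0, n_iter=5000, tol=1e-6):
        """Hypothetical solver: proximal gradient descent on
        1/(2n) ||y - X w||^2 + reg * sum_g sqrt(p_g) ||w_g||_2, started at w0.
        Returns the coefficients and the number of iterations used."""
        n = X.shape[0]
        step = n / np.linalg.norm(X, 2) ** 2  # 1 / Lipschitz constant of the gradient
        w = w0.copy()
        for it in range(1, n_iter + 1):
            grad = X.T @ (X @ w - y) / n
            w_new = group_soft_threshold(w - step * grad, groups, step * reg)
            if np.linalg.norm(w_new - w) <= tol * max(1.0, np.linalg.norm(w)):
                return w_new, it
            w = w_new
        return w, n_iter


    # Small synthetic problem (illustrative sizes, no intercept).
    rng = np.random.default_rng(0)
    groups = np.repeat(np.arange(10), 5)  # 10 groups of 5 covariates
    w_true = rng.standard_normal(50) * np.repeat(rng.integers(0, 2, 10), 5)
    X = rng.standard_normal((500, 50))
    y = X @ w_true + rng.standard_normal(500)

    regs = np.logspace(-2, 0, 10)[::-1]  # strongest regularisation first
    w_warm = np.zeros(50)
    warm_iters = cold_iters = 0
    for reg in regs:
        w_warm, it = fit_group_lasso(X, y, groups, reg, w_warm)  # start from previous fit
        warm_iters += it
        w_cold, it = fit_group_lasso(X, y, groups, reg, np.zeros(50))  # start from zero
        cold_iters += it

    print(f"total iterations: warm start {warm_iters}, cold start {cold_iters}")

Both runs reach the same solutions, since the problem is convex; only the number of iterations differs, and the size of the saving depends on the data and on how finely the regularisation path is sampled.
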








.. GENERATED FROM PYTHON SOURCE LINES 85-87

Visualise chosen covariate groups
---------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 87-94

.. code-block:: default

    plt.figure()
    # shading="nearest" centres one cell on each (covariate, regularisation) pair,
    # which is what X and Y of the same size as the colour array describe.
    plt.pcolormesh(
        np.arange(w.shape[0]), regularisations, -weights, cmap="gray", shading="nearest"
    )
    plt.yscale("log")
    plt.xlabel("Covariate number")
    plt.ylabel("Regularisation strength")
    plt.title("Active groups are black and inactive groups are white")
    plt.show()



.. image:: /auto_examples/images/sphx_glr_example_warm_start_001.png
    :alt: Active groups are black and inactive groups are white
    :class: sphx-glr-single-img



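Each row of ``weights`` is the sparsity mask for one regularisation strength, so the selected groups can also be scored against ``true_coefficient_mask``, which the example computes but does not use. The helper below is a hypothetical name, not part of ``group_lasso``; it counts group-level true positives, false positives and false negatives for one mask, assuming flattened 1-D inputs. The inputs shown are illustrative:

.. code-block:: python

    import numpy as np


    def group_selection_counts(coef_mask, true_mask, groups):
        """Hypothetical helper: (true positives, false positives, false negatives)
        counted over groups rather than over individual coefficients."""
        labels = np.unique(groups)
        selected = np.array([coef_mask[groups == g].any() for g in labels])
        active = np.array([true_mask[groups == g].any() for g in labels])
        tp = int(np.sum(selected & active))
        fp = int(np.sum(selected & ~active))
        fn = int(np.sum(~selected & active))
        return tp, fp, fn


    # Illustrative inputs: 4 groups of 2 covariates each.
    groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])
    true_mask = np.array([1, 1, 0, 0, 1, 1, 0, 0], dtype=bool)  # groups 0 and 2 active
    coef_mask = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=bool)  # groups 0 and 1 selected

    print(group_selection_counts(coef_mask, true_mask, groups))  # (1, 1, 1)

Assuming the shapes produced above, a call such as ``group_selection_counts(weights[j], true_coefficient_mask.ravel(), groups.ravel())`` would score the groups selected at ``regularisations[j]``.
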



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minute 4.583 seconds)


.. _sphx_glr_download_auto_examples_example_warm_start.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: example_warm_start.py <example_warm_start.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: example_warm_start.ipynb <example_warm_start.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
