# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_patch.ipynb.

# %% auto 0
__all__ = ['monkey_patch']

# %% ../nbs/10_patch.ipynb 4
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import array
from fastcore.foundation import patch_to
import matplotlib.pyplot as plt

from .repr_str import StrProxy
from .repr_rgb import RGBProxy
from .repr_plt import PlotProxy
from .repr_chans import ChanProxy

# %% ../nbs/10_patch.ipynb 5
def _monkey_patch(cls):
    """Install lovely-jax formatting and helper properties on `cls`.

    `__repr__` and `__str__` are both replaced so that `print(x)` and
    notebook/debugger inspection show the `StrProxy` summary. The original
    methods are stashed as `cls._plain_repr` / `cls._plain_str`, but only
    when `cls` has no `_plain_repr` yet. On a repeated call, the first-saved
    originals are kept rather than being overwritten with patched versions.

    Properties added to `cls` (via `patch_to(..., as_prop=True)`):
      - `p`:      `StrProxy(x, plain=True)`   - the plain, old-style output
      - `v`:      `StrProxy(x, verbose=True)` - stats plus plain values
      - `deeper`: `StrProxy(x, depth=1)`
      - `rgb`:    `RGBProxy(x)`
      - `chans`:  `ChanProxy(x)`
      - `plt`:    `PlotProxy(x)`
    """
    if not hasattr(cls, '_plain_repr'):
        cls._plain_repr, cls._plain_str = cls.__repr__, cls.__str__

    # print() goes through __str__, while Jupyter/VSCode inspection uses
    # __repr__; both are routed to the same lovely summary.
    def __repr__(self: jax.Array):
        return str(StrProxy(self))

    def __str__(self: jax.Array):
        return str(StrProxy(self))

    # patch_to takes the attribute name from each function's __name__.
    for dunder in (__repr__, __str__):
        patch_to(cls)(dunder)

    # Property getters: each wraps the array in a different display proxy.
    def p(self: jax.Array):       # plain - the pre-patch behaviour
        return StrProxy(self, plain=True)

    def v(self: jax.Array):       # verbose - stats and plain values
        return StrProxy(self, verbose=True)

    def deeper(self: jax.Array):
        return StrProxy(self, depth=1)

    def rgb(self: jax.Array):
        return RGBProxy(self)

    def chans(self: jax.Array):
        return ChanProxy(self)

    def plt(self: jax.Array):
        return PlotProxy(self)

    for getter in (p, v, deeper, rgb, chans, plt):
        patch_to(cls, as_prop=True)(getter)


def monkey_patch():
    "Monkey-patch lovely features into JAX's concrete array class(es)."
    _monkey_patch(array.ArrayImpl)
    # `DeviceArray` is not defined in `jax._src.array` on every JAX version
    # (current releases only have `ArrayImpl`). Accessing it unconditionally
    # raised AttributeError after ArrayImpl had already been patched, so
    # treat it as optional and skip it if it is merely an alias.
    device_array = getattr(array, 'DeviceArray', None)
    if device_array is not None and device_array is not array.ArrayImpl:
        _monkey_patch(device_array)
