# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99_pytorch_doc.ipynb.

# %% ../nbs/99_pytorch_doc.ipynb 5
from __future__ import annotations
from types import ModuleType

# %% auto 0
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ['PYTORCH_URL', 'pytorch_doc_link']

# %% ../nbs/99_pytorch_doc.ipynb 7
# Root of the stable PyTorch documentation. Every link produced in this module
# is this prefix followed by a page name (and, for members, a `#anchor`), so it
# must keep its trailing slash.
PYTORCH_URL = "https://pytorch.org/docs/stable/"

# %% ../nbs/99_pytorch_doc.ipynb 8
def _mod2page(
    mod:ModuleType, # A PyTorch module
) -> str:
    "Get the webpage name for a PyTorch module"
    if mod == Tensor: return 'tensors.html'
    name = mod.__name__
    name = name.replace('torch.', '').replace('utils.', '')
    if name.startswith('nn.modules'): return 'nn.html'
    return f'{name}.html'

# %% ../nbs/99_pytorch_doc.ipynb 10
import importlib

# %% ../nbs/99_pytorch_doc.ipynb 11
def pytorch_doc_link(
    name:str # Name of a PyTorch module, class or function
) -> str | None:
    "Get the URL to the documentation of a PyTorch module, class or function"
    # `F` is the conventional alias for `torch.nn.functional`. Expand only the
    # alias itself (`F`, `F.relu`), not every name that merely begins with "F"
    # (e.g. `FloatTensor`).
    if name == 'F' or name.startswith('F.'):
        name = 'torch.nn.functional' + name[1:]
    if name != 'torch' and not name.startswith('torch.'): name = 'torch.' + name
    if name == 'torch.Tensor': return f'{PYTORCH_URL}tensors.html'
    # First, treat `name` as a module and link to that module's page.
    # Importing torch submodules can fail in many ways, so the lookup is
    # best-effort; `Exception` (not a bare `except`) still lets
    # KeyboardInterrupt / SystemExit propagate.
    try:
        mod = importlib.import_module(name)
        return f'{PYTORCH_URL}{_mod2page(mod)}'
    except Exception: pass
    # Otherwise treat the last dotted component as a member of its parent and
    # link to the anchor `#<full name>` on the parent module's page.
    mod_name = name.rpartition('.')[0]
    if mod_name == 'torch.Tensor': return f'{PYTORCH_URL}tensors.html#{name}'
    try:
        mod = importlib.import_module(mod_name)
        return f'{PYTORCH_URL}{_mod2page(mod)}#{name}'
    except Exception: return None
