Metadata-Version: 2.1
Name: eqxvision
Version: 0.0.1
Summary: Root package info.
Author-email: Contributing Authors <aditya.91.singh@gmail.com>
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3 :: Only
Requires-Dist: jax==0.3.15
Requires-Dist: jaxlib==0.3.15
Requires-Dist: equinox
Requires-Dist: jinja2==3.0.3 ; extra == "docs"
Requires-Dist: Markdown>=3.3 ; extra == "docs"
Requires-Dist: MarkupSafe>=1.1 ; extra == "docs"
Requires-Dist: mkdocs>=1.2 ; extra == "docs"
Requires-Dist: mkdocs-autorefs>=0.3.1 ; extra == "docs"
Requires-Dist: pymdown-extensions>=6.3 ; extra == "docs"
Requires-Dist: mkdocs==1.3.0 ; extra == "docs"
Requires-Dist: mkdocs-autorefs ; extra == "docs"
Requires-Dist: mkdocs_include_exclude_files==0.0.1 ; extra == "docs"
Requires-Dist: mkdocs-material==7.3.6 ; extra == "docs"
Requires-Dist: mkdocs-material-extensions ; extra == "docs"
Requires-Dist: mkdocstrings==0.17.0 ; extra == "docs"
Requires-Dist: mkdocstrings-python ; extra == "docs"
Requires-Dist: mkdocstrings-python-legacy ; extra == "docs"
Requires-Dist: mknotebooks==0.7.1 ; extra == "docs"
Requires-Dist: pymdown-extensions==9.4 ; extra == "docs"
Requires-Dist: pytkdocs_tweaks==0.0.5 ; extra == "docs"
Requires-Dist: pre-commit ; extra == "test"
Requires-Dist: bluepy ; extra == "test"
Requires-Dist: pytest ; extra == "test"
Project-URL: Bug Tracker, https://github.com/paganpasta/eqxvision/issues
Project-URL: Homepage, https://github.com/paganpasta/eqxvision
Provides-Extra: docs
Provides-Extra: test

# Eqxvision

Eqxvision is a Python library providing computer vision models to the [Equinox](https://docs.kidger.site/equinox/) ecosystem.

## Installation

Use the package manager [pip](https://pip.pypa.io/en/stable/) to install eqxvision.

```bash
pip install eqxvision
```

## Usage

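```python
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import resnet18

@eqx.filter_jit
def forward(net, images, key):
    # One PRNG key per image; vmap maps keyword arguments over axis 0 too.
    keys = jr.split(key, images.shape[0])
    # ResNets contain batch normalization, so the batch axis must be named.
    return jax.vmap(net, axis_name="batch")(images, key=keys)

net = resnet18(num_classes=1000)

# Use separate keys for data generation and the forward pass.
data_key, forward_key = jr.split(jr.PRNGKey(0))
images = jr.uniform(data_key, shape=(1, 3, 224, 224))
output = forward(net, images, forward_key)
```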
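The `forward` function above splits one PRNG key per image because `jax.vmap` maps keyword arguments over their leading axis as well as positional ones, so every example receives its own key. The following is a minimal pure-JAX sketch of that pattern; `noisy_layer` is a hypothetical stand-in for a stochastic layer such as dropout, not part of eqxvision.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def noisy_layer(x, *, key):
    # Hypothetical dropout-like layer: keeps each element with probability 0.5
    # and rescales the survivors so the expected value is unchanged.
    keep = jr.bernoulli(key, p=0.5, shape=x.shape)
    return jnp.where(keep, x / 0.5, 0.0)

def forward(images, key):
    # One independent key per example, mapped over the leading (batch) axis.
    keys = jr.split(key, images.shape[0])
    return jax.vmap(noisy_layer)(images, key=keys)

images = jnp.ones((4, 3, 8, 8))
out = forward(images, jr.PRNGKey(0))
print(out.shape)  # (4, 3, 8, 8)
```

Reusing a single key for every example would instead give each image the same random mask.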

## Tips
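- Models with batch normalization layers (such as the ResNets) must be vmapped over a named batch axis: `jax.vmap(net, axis_name="batch")(images, key=keys)`.
- Don't forget to switch the model to inference mode before evaluation, e.g. `net = eqx.tree_inference(net, value=True)`.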
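Both tips come down to how batch normalization computes its statistics. In training mode it averages over the batch, which inside `jax.vmap` requires a named axis to reduce over; in inference mode it uses stored running statistics and needs no batch axis at all. The sketch below illustrates this with `jax.lax.pmean`. The `batch_norm` function and its boolean `inference` argument are hypothetical simplifications (no learned scale/shift, no running-statistics update), not Equinox's actual BatchNorm implementation.

```python
import jax
import jax.numpy as jnp

def batch_norm(x, running_mean, running_var, *, inference, eps=1e-5):
    # Hypothetical minimal batch norm applied to a single feature vector `x`.
    if inference:
        # Inference mode: use stored running statistics; no batch axis needed.
        mean, var = running_mean, running_var
    else:
        # Training mode: average across the vmapped axis named "batch".
        mean = jax.lax.pmean(x, axis_name="batch")
        var = jax.lax.pmean((x - mean) ** 2, axis_name="batch")
    return (x - mean) / jnp.sqrt(var + eps)

xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # batch of 3, 2 features
zeros, ones = jnp.zeros(2), jnp.ones(2)

# Training mode only works when the vmapped axis is named "batch".
train_out = jax.vmap(
    lambda x: batch_norm(x, zeros, ones, inference=False), axis_name="batch"
)(xs)
# Inference mode ignores the batch, so a plain vmap suffices.
eval_out = jax.vmap(lambda x: batch_norm(x, zeros, ones, inference=True))(xs)
```

Dropping `axis_name="batch"` from the training-mode call makes `pmean` fail with an unbound-axis-name error, which is why models with batch normalization need the named axis.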

## Roadmap

- [ ] Add VGGs, Inception, GoogLeNet
- [ ] Add (or explore) support for loading weights directly from PyTorch `.pth` files
- [ ] Doc fixes
- [ ] Build fixes
- [ ] Pre-commit Hooks


## Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

## Acknowledgements
- [Torchvision](https://pytorch.org/vision/stable/index.html)
- [Equinox](https://github.com/patrick-kidger/equinox)
- [Patrick Kidger](https://github.com/patrick-kidger)

## License
[MIT](https://choosealicense.com/licenses/mit/)
