Metadata-Version: 2.1
Name: torchmix
Version: 0.1.0rc2
Summary: 
Author: junhsss
Author-email: junhsssr@gmail.com
Requires-Python: >=3.10,<4.0
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: einops (>=0.6.0,<0.7.0)
Requires-Dist: hydra-core (>=1.3.1,<2.0.0)
Requires-Dist: hydra-zen (>=0.8.0,<0.9.0)
Requires-Dist: jaxtyping (>=0.2.11,<0.3.0)
Description-Content-Type: text/markdown

<h1 align="center">torchmix</h1>

<h3 align="center">The missing component library for PyTorch</h3>

<br />

Welcome to torchmix, a collection of PyTorch modules that aims to reduce boilerplate and improve code modularity.

**Please note: `torchmix` is currently in development and has not been tested for production use. The API may change at any time.**

<br />

## Usage

To use `torchmix`, simply import the desired components:

```python
import torchmix.nn as nn  # Wrapped version of torch.nn
from torchmix import (
    Add,
    Attach,
    AvgPool,
    ChannelMixer,
    Extract,
    PatchEmbed,
    PositionEmbed,
    PreNorm,
    Repeat,
    SelfAttention,
    Token,
)
```

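You can compose these components to build more complex architectures. For example, the two Vision Transformer (ViT) encoders below differ mainly in how they summarize the token sequence: `vit_cls` attaches a class token and reads it out with `Extract(0)`, while `vit_avg` averages over the patch tokens with `AvgPool()`.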

```python
vit_cls = nn.Sequential(
    Add(
        Attach(
            PatchEmbed(dim=1024),
            Token(dim=1024),
        ),
        PositionEmbed(
            seq_length=196 + 1,  # 196 patch tokens + 1 class token
            dim=1024,
        ),
    ),
    Repeat(
        nn.Sequential(
            PreNorm(
                ChannelMixer(
                    dim=1024,
                    expansion_factor=4,
                    act_layer=nn.GELU.partial(),
                ),
                dim=1024,
            ),
            PreNorm(
                SelfAttention(
                    dim=1024,
                    num_heads=8,
                    head_dim=64,
                ),
                dim=1024,
            ),
        )
    ),
    Extract(0),
)

vit_avg = nn.Sequential(
    Add(
        PatchEmbed(dim=1024),
        PositionEmbed(
            seq_length=196,  # patch tokens only, no class token
            dim=1024,
        ),
    ),
    Repeat(
        nn.Sequential(
            PreNorm(
                ChannelMixer(
                    dim=1024,
                    expansion_factor=4,
                    act_layer=nn.GELU.partial(),
                ),
                dim=1024,
            ),
            PreNorm(
                SelfAttention(
                    dim=1024,
                    num_heads=8,
                    head_dim=64,
                ),
                dim=1024,
            ),
        )
    ),
    AvgPool(),
)
```
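The constants in these snippets follow standard ViT bookkeeping. The sketch below reproduces them under assumptions the README does not state: 224 × 224 inputs cut into 16 × 16 patches, a `SelfAttention` that projects to `num_heads * head_dim` and back to `dim`, and a `ChannelMixer` whose hidden width is `dim * expansion_factor`. Check the actual defaults of `PatchEmbed` before relying on these numbers.

```python
# Assumed input settings (not given in the README snippet): a common ViT
# setting of 224 x 224 images split into non-overlapping 16 x 16 patches.
image_size = 224
patch_size = 16

# 14 patches per side, 14 * 14 patches in total.
patches_per_side = image_size // patch_size
num_patches = patches_per_side * patches_per_side

# vit_cls adds one class token to the patch tokens; vit_avg does not.
seq_length_cls = num_patches + 1
seq_length_avg = num_patches

# Width used inside attention, assuming q/k/v are projected to
# num_heads * head_dim and the output is projected back to dim.
dim, num_heads, head_dim = 1024, 8, 64
attention_inner_dim = num_heads * head_dim

# Hidden width of the channel MLP, assuming it equals dim * expansion_factor.
expansion_factor = 4
mlp_hidden_dim = dim * expansion_factor

print(num_patches, seq_length_cls, seq_length_avg)  # → 196 197 196
print(attention_inner_dim, mlp_hidden_dim)          # → 512 4096
```

Note that under these assumptions the attention's inner width (512) differs from the model width (1024), which is why `head_dim` is passed separately from `dim`.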

### Integration with Hydra

Reproducibility matters, so it is good practice to store the configuration of every model you train. However, writing configurations by hand for complex, deeply nested PyTorch modules is tedious and yields code that is hard to read and maintain: a parent class often has to accept and forward the parameters of its child classes, which leads to long argument lists and tight coupling between parent and child.

`torchmix` simplifies this process by **auto-magically** generating the full configuration of a PyTorch module **simply by instantiating it.** This enables effortless integration with the `hydra` ecosystem, which allows for easy storage and management of module configurations.
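How can instantiation alone produce a full configuration? One plausible mechanism, sketched below with only the standard library, is to wrap each subclass's `__init__` so that it records its bound arguments (defaults included), and to turn nested recorded objects into nested configs. `Recordable`, `as_config`, and the stand-in `Linear`/`Sequential` classes are hypothetical; this is an illustration of the idea, not torchmix's actual `MixModule` implementation.

```python
import inspect
from typing import Any


class Recordable:
    """Hypothetical base class (not torchmix's MixModule): each instance
    remembers the arguments its constructor was called with."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if "__init__" not in cls.__dict__:
            return  # the inherited __init__ is already wrapped
        original_init = cls.__init__
        signature = inspect.signature(original_init)

        def __init__(self, *args: Any, **kw: Any) -> None:
            if not hasattr(self, "_recorded"):  # the outermost subclass records
                bound = signature.bind(self, *args, **kw)
                bound.apply_defaults()  # also keep arguments left at defaults
                self._recorded = (signature, dict(bound.arguments))
            original_init(self, *args, **kw)

        cls.__init__ = __init__

    def to_config(self) -> dict[str, Any]:
        cls = type(self)
        config: dict[str, Any] = {"_target_": f"{cls.__module__}.{cls.__qualname__}"}
        signature, arguments = self._recorded
        for name, value in arguments.items():
            if name == "self":
                continue
            kind = signature.parameters[name].kind
            if kind is inspect.Parameter.VAR_POSITIONAL:
                config["_args_"] = [as_config(v) for v in value]  # *args
            elif kind is inspect.Parameter.VAR_KEYWORD:
                config.update({k: as_config(v) for k, v in value.items()})
            else:
                config[name] = as_config(value)
        return config


def as_config(value: Any) -> Any:
    """Nested recordable objects become nested configs; other values pass through."""
    return value.to_config() if isinstance(value, Recordable) else value


class Linear(Recordable):
    # Same parameter names and defaults as torch.nn.Linear.
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        self.in_features = in_features
        self.out_features = out_features


class Sequential(Recordable):
    def __init__(self, *modules):
        self.modules = list(modules)


model = Sequential(Linear(1024, 4096), Linear(4096, 1024))
config = model.to_config()
print(config)
```

Because each child records its own arguments, the parent `Sequential` never has to accept or forward them, which removes the coupling described above. The printed dict has the same `_target_`/`_args_` nesting as the YAML generated later in this README, with `__main__` targets in place of `torchmix.nn` ones.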

For example, to generate the configuration of a typical MLP, build it from `torchmix.nn` just as you would with `torch.nn`:

```python
from torchmix import nn

model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.Dropout(0.1),
    nn.GELU(),
    nn.Linear(4096, 1024),
    nn.Dropout(0.1),
)
```

You can then store the configuration in Hydra's `ConfigStore`:

```python
model.store(group="model", name="mlp")
```

Alternatively, you can export it to a YAML file:

```python
model.export("mlp.yaml")
```

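This generates the following configuration. Note that it also records arguments left at their default values, such as `bias`, `device`, and `dtype`: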

```yaml
_target_: torchmix.nn.Sequential
_args_:
  - _target_: torchmix.nn.Linear
    in_features: 1024
    out_features: 4096
    bias: true
    device: null
    dtype: null
  - _target_: torchmix.nn.Dropout
    p: 0.1
    inplace: false
  - _target_: torchmix.nn.GELU
    approximate: none
  - _target_: torchmix.nn.Linear
    in_features: 4096
    out_features: 1024
    bias: true
    device: null
    dtype: null
  - _target_: torchmix.nn.Dropout
    p: 0.1
    inplace: false
```

You can always recreate the actual PyTorch module from its configuration with Hydra's `hydra.utils.instantiate` function.
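Hydra's `instantiate` walks the config recursively: it builds every nested node that has a `_target_`, passes `_args_` positionally and the remaining keys as keyword arguments, and then calls the target. The standard-library sketch below mimics only that core behavior to show how the YAML above maps back to constructor calls. It is a simplification (no `_partial_`, `_convert_`, interpolation, or error handling), and the helper names `locate` and `instantiate` here are local stand-ins; in practice you would call Hydra's own `hydra.utils.instantiate`.

```python
import importlib
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as 'fractions.Fraction' (module-level names only)."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate(node: Any) -> Any:
    """Simplified stand-in for hydra.utils.instantiate: recursive, no extras."""
    if isinstance(node, list):
        return [instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node  # plain values (numbers, strings, None) pass through
    # Build nested nodes first, so children exist before their parent.
    kwargs = {key: instantiate(value) for key, value in node.items()
              if key not in ("_target_", "_args_")}
    if "_target_" not in node:
        return kwargs  # a plain mapping without a target stays a dict
    args = [instantiate(arg) for arg in node.get("_args_", [])]
    return locate(node["_target_"])(*args, **kwargs)


# Stdlib targets stand in for torchmix.nn.* so the example runs without PyTorch.
fraction_config = {
    "_target_": "fractions.Fraction",
    "_args_": [{"_target_": "decimal.Decimal", "_args_": ["0.5"]}],
}
delta_config = {"_target_": "datetime.timedelta", "days": 1, "hours": 2}

print(instantiate(fraction_config))  # → 1/2
print(instantiate(delta_config))     # → 1 day, 2:00:00
```

The inner `Decimal` node is built before the outer `Fraction` receives it, which mirrors how each `torchmix.nn.Linear` entry in `_args_` would be built before the enclosing `Sequential`.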

To give your own modules this functionality, subclass `MixModule` and define your module as you normally would:

```python
from torchmix import MixModule

class CustomModule(MixModule):
    def __init__(self, num_heads, dim, depth):
        super().__init__()
        # Define submodules here, as in any torch.nn.Module subclass.

custom_module = CustomModule(16, 768, 12)
custom_module.store(group="model", name="custom")
```

## Documentation

Documentation is currently in progress. Please stay tuned! 🚀

## Contributing

We welcome contributions to the `torchmix` library. If you have ideas for new components or suggestions for improving the library, don't hesitate to open an issue or start a discussion. Please note that `torchmix` is still in the prototype phase, so any contributions should be considered experimental.

## License

`torchmix` is licensed under the MIT License.

