Metadata-Version: 2.1
Name: pytorch_saver
Version: 0.1.1
Summary: Simple helper to save and load PyTorch models
Author-email: Matheus Pedroni <pnmatheus@protonmail.com>
License: MIT License
        
        Copyright (c) 2022 Matheus Pedroni
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.8
Requires-Dist: torch>=1.8.2
Description-Content-Type: text/markdown

# pytorch-saver
Simple helper to save and load PyTorch models.

Repository: https://github.com/mathpn/pytorch-saver

## Why use this package to save and load models?

PyTorch [suggests](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html) two standard ways of saving and loading models: either saving their state_dict or saving the pickled model object itself.

Both methods have drawbacks:

- Saving the state_dict is very flexible, but we lose all the arguments used to create the model, the optimizer and (optionally) the scheduler;
- Saving a pickled snapshot solves this issue, but it is not flexible at all. Even minor changes to the model class can break the unpickling process, and the arguments used to define the objects are still obscured behind the objects themselves.
  
Therefore, the goal of this package is to provide a practical way of creating models and their associated objects, then saving and loading them without headaches. Any additional metadata can also be included in the saved file.

## Installing

Install with pip:

    pip install pytorch-saver

Or clone the repository and go inside its folder:

    git clone https://github.com/mathpn/pytorch-saver
    cd pytorch-saver

and install from source:

    pip install .


## How to use it

### Initializing objects

Import ModelContainer and create a new container instance.

    from pytorch_saver import ModelContainer
    container = ModelContainer()

This is the only part of the pipeline that breaks with Python conventions. Since all arguments used to create the objects must be stored so that the objects can be recreated later, the objects are created through the initialize method rather than by calling their classes directly.

Pass each class to the initialize method, followed by a dictionary with the keyword arguments used to initialize it.

    model_objects = container.initialize(
        Model,
        model_kwargs,
        torch.optim.Adam,
        optim_kwargs
    )

model_objects is a NamedTuple with three attributes: model, optimizer and scheduler. Access these objects (if created through initialize) and train your model.
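
The idea of creating objects through a container that remembers their arguments can be sketched in plain Python. This is only an illustration of the pattern, not the actual ModelContainer implementation: the names RecordingContainer, Objects and TinyModel are hypothetical, and TinyModel stands in for a torch.nn.Module so the example runs without PyTorch.

```python
from typing import Any, Dict, NamedTuple, Optional, Type


class Objects(NamedTuple):
    """Hypothetical stand-in for the NamedTuple returned by initialize."""
    model: Any
    optimizer: Optional[Any] = None
    scheduler: Optional[Any] = None


class RecordingContainer:
    """Illustrative only: not the real ModelContainer."""

    def __init__(self) -> None:
        self.recorded_kwargs: Dict[str, Dict[str, Any]] = {}

    def initialize(self, model_cls: Type, model_kwargs: Dict[str, Any]) -> Objects:
        # Keep a copy of the arguments so the object can be rebuilt later.
        self.recorded_kwargs["model"] = dict(model_kwargs)
        return Objects(model=model_cls(**model_kwargs))


class TinyModel:
    """Placeholder used in place of a torch.nn.Module subclass."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


container = RecordingContainer()
objs = container.initialize(TinyModel, {"hidden_size": 16})

# The recorded arguments are enough to recreate an equivalent object.
rebuilt = TinyModel(**container.recorded_kwargs["model"])
```

Because the class is called inside initialize, the container sees the keyword arguments before the object exists, which is what makes them available for saving.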

### Saving checkpoints

Use the save method to save checkpoints:

    container.save("./", "tutorial")

This will save a checkpoint to "./tutorial_checkpoint_TIMESTAMP.zip", where TIMESTAMP is the current Unix timestamp in seconds.
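
As a rough sketch of the naming scheme described above, the file name could be built as follows. The helper checkpoint_path is hypothetical and is not part of the package; it only reproduces the documented pattern using the standard library.

```python
import os
import time


def checkpoint_path(directory: str, name: str) -> str:
    """Hypothetical helper mirroring the documented naming scheme."""
    timestamp = int(time.time())  # Unix timestamp in whole seconds
    return os.path.join(directory, f"{name}_checkpoint_{timestamp}.zip")


path = checkpoint_path("./", "tutorial")
```

Since the timestamp has a resolution of one second, two checkpoints saved with the same name within the same second would map to the same file name under this scheme.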

Any additional keyword arguments provided will be saved as model metadata, as long as they are JSON-serializable.

    container.save("./", "tutorial", loss=0.55, epoch=5)
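
To check in advance whether metadata values are JSON-serializable, one option is to try encoding each value with the standard json module. The helper below is a hypothetical sketch and is not part of the package.

```python
import json
from typing import Any, Dict, List


def non_serializable_keys(metadata: Dict[str, Any]) -> List[str]:
    """Return the keys whose values json.dumps cannot encode."""
    bad = []
    for key, value in metadata.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            bad.append(key)
    return bad


ok_keys = non_serializable_keys({"loss": 0.55, "epoch": 5})
bad_keys = non_serializable_keys({"loss": 0.55, "tags": {"a", "b"}})  # a set is not JSON-serializable
```

A common case in practice is a loss stored as a tensor: converting it to a Python float first (for example with loss.item()) makes it JSON-serializable.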

If you only want to save the model (ignoring optimizer and scheduler), use the save_inference method.

### Loading saved files

Use the load method to load checkpoints:

    from pytorch_saver import ModelContainer
    container = ModelContainer()
    metadata, objs = container.load(file_path)

"metadata" is a dictionary with the arguments used to initialize all objects, the timestamp, and any additional keyword arguments passed to the save method when this file was saved.

"objs" is a NamedTuple with the same structure as the one returned by the initialize method.
