Metadata-Version: 2.1
Name: toy-gradlogp-tf2
Version: 0.1.0
Summary: Some toy examples of score matching algorithms written in TensorFlow 2.0
Home-page: https://github.com/Ending2015a/toy_gradlogp_tf2
Author: JoeHsiao
Author-email: joehsiao@gapp.nthu.edu.tw
License: MIT
Keywords: score-matching playform tensorflow2.0 tensorflow
Platform: UNKNOWN
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE


<img src="https://user-images.githubusercontent.com/18180004/136144615-0cd92028-8226-40c1-81ee-fa6c067e91e3.png" align="right" width="25%"/>

# toy_gradlogp_tf2

This repo implements some toy examples of the following score matching algorithms in TensorFlow 2.0:
* `ssm-vr`: [sliced score matching](https://arxiv.org/abs/1905.07088) with variance reduction
* `ssm`: [sliced score matching](https://arxiv.org/abs/1905.07088)
* `deen`: [deep energy estimator networks](https://arxiv.org/abs/1805.08306)
* `dsm`: [denoising score matching](http://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf)
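
To give a rough idea of what the `ssm`, `ssm-vr`, and `dsm` objectives compute, here is a sketch that evaluates them in plain Python. It uses a hypothetical toy model whose score and Jacobian have closed forms, so it needs neither TensorFlow nor autodiff. It is not this repo's TensorFlow implementation. The Gaussian energy, the function names, and the `sigma` default are assumptions made for illustration, and `deen` is left out. The three projection distributions mirror the `--noise` choices of the training script (which spells the first one `radermacher`). Each is scaled so that `E[v v^T] = I`. That property is what lets `ssm-vr` replace the random term `(v^T s)^2` by its expectation `||s||^2`.

```python
import math
import random


# Hypothetical toy model for illustration: E(x) = ||x - mu||^2 / (2 * tau^2).
# Its score is s(x) = -grad E(x) = -(x - mu) / tau^2 and its Jacobian is -I / tau^2.
def score(x, mu, tau):
    return [-(xi - mi) / tau ** 2 for xi, mi in zip(x, mu)]


def sample_projection(d, noise, rng):
    """Draw a projection vector v with E[v v^T] = I for the given noise type."""
    if noise == "rademacher":
        return [rng.choice((-1.0, 1.0)) for _ in range(d)]
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    if noise == "gaussian":
        return g
    if noise == "sphere":
        # Uniform direction on the unit sphere, rescaled to radius sqrt(d).
        norm = math.sqrt(sum(gi * gi for gi in g))
        return [math.sqrt(d) * gi / norm for gi in g]
    raise ValueError("unknown noise type: " + noise)


def sliced_score_matching_loss(data, mu, tau, noise="rademacher",
                               variance_reduction=True, seed=0):
    """Monte Carlo estimate of E_x E_v[v^T (ds/dx) v + 0.5 * (v^T s)^2]."""
    rng = random.Random(seed)
    total = 0.0
    for x in data:
        s = score(x, mu, tau)
        v = sample_projection(len(x), noise, rng)
        # v^T (ds/dx) v with ds/dx = -I / tau^2
        first = -sum(vi * vi for vi in v) / tau ** 2
        if variance_reduction:
            # ssm-vr: E_v[(v^T s)^2] = ||s||^2, so this term is computed exactly
            second = 0.5 * sum(si * si for si in s)
        else:
            second = 0.5 * sum(vi * si for vi, si in zip(v, s)) ** 2
        total += first + second
    return total / len(data)


def denoising_score_matching_loss(data, mu, tau, sigma=0.5, seed=0):
    """Estimate of E[0.5 * ||s(x~) + (x~ - x) / sigma^2||^2] with x~ = x + sigma * eps."""
    rng = random.Random(seed)
    total = 0.0
    for x in data:
        x_noisy = [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]
        s = score(x_noisy, mu, tau)
        # -(x~ - x) / sigma^2 is the score of the Gaussian perturbation kernel q(x~ | x)
        total += 0.5 * sum((si + (xn - xi) / sigma ** 2) ** 2
                           for si, xn, xi in zip(s, x_noisy, x))
    return total / len(data)


if __name__ == "__main__":
    rng = random.Random(42)
    data = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(2000)]
    mu = [0.0, 0.0]
    for tau in (0.5, 1.0, 2.0):
        print(tau,
              sliced_score_matching_loss(data, mu, tau),  # ssm-vr
              sliced_score_matching_loss(data, mu, tau, variance_reduction=False),  # ssm
              denoising_score_matching_loss(data, mu, tau))  # dsm
```

With 2-D standard-normal data, all three estimates should be lowest at `tau = 1` among the values tried. That is where the model score matches the data score. For `dsm`, the optimum shifts slightly, to `tau^2 = 1 + sigma^2`, the score of the noise-perturbed data.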

Related projects:
* [toy_gradlogp](https://github.com/Ending2015a/toy_gradlogp): PyTorch implementation.

## Installation
Basic requirements:
* Python >= 3.6
* TensorFlow >= 2.3.0

Install from PyPI:
```shell
pip install toy_gradlogp_tf2
```

Or install the latest version from this repo:
```shell
pip install git+https://github.com/Ending2015a/toy_gradlogp_tf2.git@master
```
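
To check that an environment meets the basic requirements listed above, a small script like the one below may help. It is a hedged sketch and not part of the package. `parse_version` and `check_requirements` are hypothetical helper names, and pre-release suffixes such as `rc1` are ignored when versions are compared.

```python
import re
import sys

MIN_PYTHON = (3, 6)
MIN_TENSORFLOW = (2, 3, 0)


def parse_version(version):
    """Turn '2.3.0' or '2.4.0rc1' into a tuple of leading integers, e.g. (2, 4, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_requirements():
    """Return a list of problems; an empty list means the basic requirements are met."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append("Python >= 3.6 is required")
    try:
        import tensorflow as tf
    except ImportError:
        problems.append("TensorFlow is not installed")
    else:
        if parse_version(tf.__version__) < MIN_TENSORFLOW:
            problems.append("TensorFlow >= 2.3.0 is required, found " + tf.__version__)
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print(problem)
```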

## Examples
The example scripts are in [toy_gradlogp/run/](https://github.com/Ending2015a/toy_gradlogp_tf2/tree/master/toy_gradlogp/run).

### Train an energy model

Run `ssm-vr` on the `2spirals` dataset:
```shell
python -m toy_gradlogp.run.train_energy --loss ssm-vr --data 2spirals
```
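
To reproduce every loss/dataset pair shown under Results, the same command can be looped over all the choices listed in the `--help` output. The sketch below is a hypothetical helper and is not shipped with the repo. It assumes that `--logdir` selects the output directory (its help text does not describe it), and the `logs/<loss>_<data>` layout is made up for illustration. By default it only prints the commands.

```python
import subprocess
import sys

# Choices copied from the `--help` output of toy_gradlogp.run.train_energy.
LOSSES = ["ssm-vr", "ssm", "deen", "dsm"]
DATASETS = ["8gaussians", "2spirals", "checkerboard", "rings"]


def build_commands(python=sys.executable, log_root="logs"):
    """One command per (loss, dataset) pair; the per-run logdir layout is hypothetical."""
    return [
        [python, "-m", "toy_gradlogp.run.train_energy",
         "--loss", loss, "--data", data,
         "--logdir", f"{log_root}/{loss}_{data}"]
        for loss in LOSSES
        for data in DATASETS
    ]


def run_all(dry_run=True):
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Print the commands by default; pass --run to launch them one after another.
    run_all(dry_run="--run" not in sys.argv)
```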

To list all options, pass `--help`:
```shell
python -m toy_gradlogp.run.train_energy --help
```

```
usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen,dsm}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--n_steps N_STEPS] [--eps EPS]
                       [--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ]
                       [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
                        dataset
  --loss {ssm-vr,ssm,deen,dsm}
                        loss type
  --noise {radermacher,sphere,gaussian}
                        noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --n_steps N_STEPS     number of steps for langevin dynamics
  --eps EPS             noise scale for langevin dynamics
  --log_freq LOG_FREQ   logging frequency (unit: epoch)
  --eval_freq EVAL_FREQ
                        evaluation frequency (unit: epoch)
  --vis_freq VIS_FREQ   visualization frequency (unit: epoch)
```
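
`--n_steps` and `--eps` configure Langevin dynamics, which draws samples by following the learned score plus Gaussian noise. The exact update the script uses is not documented here. The sketch below therefore assumes the common form `x <- x + eps * s(x) + sqrt(2 * eps) * z` with `z ~ N(0, I)` and treats `eps` as the step size. The Gaussian target and the function names are illustrative only.

```python
import math
import random


def langevin_dynamics(score_fn, x0, n_steps=200, eps=0.1, rng=None):
    """Unadjusted Langevin dynamics: x <- x + eps * s(x) + sqrt(2 * eps) * z, z ~ N(0, I)."""
    rng = rng or random.Random(0)
    x = list(x0)
    for _ in range(n_steps):
        s = score_fn(x)
        x = [xi + eps * si + math.sqrt(2.0 * eps) * rng.gauss(0.0, 1.0)
             for xi, si in zip(x, s)]
    return x


# Illustrative target: N(mu, I), whose score is s(x) = -(x - mu).
MU = [2.0, -1.0]


def gaussian_score(x):
    return [-(xi - mi) for xi, mi in zip(x, MU)]


rng = random.Random(0)
samples = [langevin_dynamics(gaussian_score, [0.0, 0.0], rng=rng) for _ in range(1000)]
mean = [sum(s[i] for s in samples) / len(samples) for i in range(2)]
print(mean)  # should be close to [2.0, -1.0]
```

Because the target is `N(mu, I)`, the empirical mean of the 1000 chains should land near `mu`. With a trained energy model, `score_fn` would be minus the gradient of the learned energy.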

## Results

Tip: regions of higher density have lower energy.
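
The tip follows from the Boltzmann form commonly used for energy-based models, `p(x) = exp(-E(x)) / Z`. That form is assumed here rather than stated in this README. Normalizing a few made-up energies shows two things: lower energy means higher probability, and only energy differences matter.

```python
import math


def energy_to_probability(energies):
    """Boltzmann normalization p_i = exp(-E_i) / sum_j exp(-E_j) over a finite set of points."""
    lowest = min(energies)  # shifting all energies by a constant leaves p unchanged
    weights = [math.exp(-(e - lowest)) for e in energies]
    total = sum(weights)
    return [w / total for w in weights]


energies = [0.5, 2.0, -1.0, 3.5]
probs = energy_to_probability(energies)
# The point with the lowest energy (-1.0, index 2) gets the highest probability,
# and p_i / p_j = exp(E_j - E_i).
print(max(range(len(probs)), key=probs.__getitem__))  # 2
```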

### `8gaussians`

| Algorithm | Results|
|-|-|
|`ssm-vr`|![](/assets/ssm-vr_8gaussians.png)|
|`ssm`|![](/assets/ssm_8gaussians.png)|
|`deen`| ![](/assets/deen_8gaussians.png) |
|`dsm`| ![](/assets/dsm_8gaussians.png) |

### `2spirals`

| Algorithm | Results|
|-|-|
|`ssm-vr`|![](/assets/ssm-vr_2spirals.png)|
|`ssm`|![](/assets/ssm_2spirals.png)|
|`deen`| ![](/assets/deen_2spirals.png) |
|`dsm`| ![](/assets/dsm_2spirals.png) |

### `checkerboard`

| Algorithm | Results|
|-|-|
|`ssm-vr`|![](/assets/ssm-vr_checkerboard.png)|
|`ssm`|![](/assets/ssm_checkerboard.png)|
|`deen`| ![](/assets/deen_checkerboard.png) |
|`dsm`| ![](/assets/dsm_checkerboard.png) |

### `rings`

| Algorithm | Results|
|-|-|
|`ssm-vr`|![](/assets/ssm-vr_rings.png)|
|`ssm`|![](/assets/ssm_rings.png)|
|`deen`| ![](/assets/deen_rings.png) |
|`dsm`| ![](/assets/dsm_rings.png) |


